Spaces:
Runtime error
Runtime error
Commit
·
8382f82
1
Parent(s):
3a7ead3
update: training code
Browse files
guardrails_genie/train_classifier.py
CHANGED
|
@@ -42,10 +42,11 @@ def train_binary_classifier(
|
|
| 42 |
dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
|
| 43 |
model_name: str = "distilbert/distilbert-base-uncased",
|
| 44 |
prompt_column_name: str = "prompt",
|
| 45 |
-
learning_rate: float =
|
| 46 |
batch_size: int = 16,
|
| 47 |
num_epochs: int = 2,
|
| 48 |
weight_decay: float = 0.01,
|
|
|
|
| 49 |
streamlit_mode: bool = False,
|
| 50 |
):
|
| 51 |
wandb.init(project=project_name, entity=entity_name, name=run_name)
|
|
@@ -88,7 +89,8 @@ def train_binary_classifier(
|
|
| 88 |
num_train_epochs=num_epochs,
|
| 89 |
weight_decay=weight_decay,
|
| 90 |
eval_strategy="epoch",
|
| 91 |
-
save_strategy="
|
|
|
|
| 92 |
load_best_model_at_end=True,
|
| 93 |
push_to_hub=False,
|
| 94 |
report_to="wandb",
|
|
|
|
| 42 |
dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
|
| 43 |
model_name: str = "distilbert/distilbert-base-uncased",
|
| 44 |
prompt_column_name: str = "prompt",
|
| 45 |
+
learning_rate: float = 1e-5,
|
| 46 |
batch_size: int = 16,
|
| 47 |
num_epochs: int = 2,
|
| 48 |
weight_decay: float = 0.01,
|
| 49 |
+
save_steps: int = 1000,
|
| 50 |
streamlit_mode: bool = False,
|
| 51 |
):
|
| 52 |
wandb.init(project=project_name, entity=entity_name, name=run_name)
|
|
|
|
| 89 |
num_train_epochs=num_epochs,
|
| 90 |
weight_decay=weight_decay,
|
| 91 |
eval_strategy="epoch",
|
| 92 |
+
save_strategy="steps",
|
| 93 |
+
save_steps=save_steps,
|
| 94 |
load_best_model_at_end=True,
|
| 95 |
push_to_hub=False,
|
| 96 |
report_to="wandb",
|