Spaces:
Runtime error
Runtime error
Commit
·
351c0ef
1
Parent(s):
f94b561
add: docs for train classifier
Browse files- docs/train_classifier.md +3 -0
- guardrails_genie/train_classifier.py +56 -3
- mkdocs.yml +1 -0
docs/train_classifier.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Train Classifier
|
| 2 |
+
|
| 3 |
+
::: guardrails_genie.train_classifier
|
guardrails_genie/train_classifier.py
CHANGED
|
@@ -16,6 +16,22 @@ import wandb
|
|
| 16 |
|
| 17 |
|
| 18 |
class StreamlitProgressbarCallback(TrainerCallback):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def __init__(self, *args, **kwargs):
|
| 21 |
super().__init__(*args, **kwargs)
|
|
@@ -42,6 +58,8 @@ def train_binary_classifier(
|
|
| 42 |
dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
|
| 43 |
model_name: str = "distilbert/distilbert-base-uncased",
|
| 44 |
prompt_column_name: str = "prompt",
|
|
|
|
|
|
|
| 45 |
learning_rate: float = 1e-5,
|
| 46 |
batch_size: int = 16,
|
| 47 |
num_epochs: int = 2,
|
|
@@ -49,6 +67,44 @@ def train_binary_classifier(
|
|
| 49 |
save_steps: int = 1000,
|
| 50 |
streamlit_mode: bool = False,
|
| 51 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
wandb.init(project=project_name, entity=entity_name, name=run_name)
|
| 53 |
if streamlit_mode:
|
| 54 |
st.markdown(
|
|
@@ -69,9 +125,6 @@ def train_binary_classifier(
|
|
| 69 |
predictions = np.argmax(predictions, axis=1)
|
| 70 |
return accuracy.compute(predictions=predictions, references=labels)
|
| 71 |
|
| 72 |
-
id2label = {0: "SAFE", 1: "INJECTION"}
|
| 73 |
-
label2id = {"SAFE": 0, "INJECTION": 1}
|
| 74 |
-
|
| 75 |
model = AutoModelForSequenceClassification.from_pretrained(
|
| 76 |
model_name,
|
| 77 |
num_labels=2,
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class StreamlitProgressbarCallback(TrainerCallback):
|
| 19 |
+
"""
|
| 20 |
+
StreamlitProgressbarCallback is a custom callback for the Hugging Face Trainer
|
| 21 |
+
that integrates a progress bar into a Streamlit application. This class updates
|
| 22 |
+
the progress bar at each training step, providing real-time feedback on the
|
| 23 |
+
training process within the Streamlit interface.
|
| 24 |
+
|
| 25 |
+
Attributes:
|
| 26 |
+
progress_bar (streamlit.delta_generator.DeltaGenerator): A Streamlit progress
|
| 27 |
+
bar object initialized to 0 with the text "Training".
|
| 28 |
+
|
| 29 |
+
Methods:
|
| 30 |
+
on_step_begin(args, state, control, **kwargs):
|
| 31 |
+
Updates the progress bar at the beginning of each training step. The progress
|
| 32 |
+
is calculated as the percentage of completed steps out of the total steps.
|
| 33 |
+
The progress bar text is updated to show the current step and the total steps.
|
| 34 |
+
"""
|
| 35 |
|
| 36 |
def __init__(self, *args, **kwargs):
|
| 37 |
super().__init__(*args, **kwargs)
|
|
|
|
| 58 |
dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
|
| 59 |
model_name: str = "distilbert/distilbert-base-uncased",
|
| 60 |
prompt_column_name: str = "prompt",
|
| 61 |
+
id2label: dict[int, str] = {0: "SAFE", 1: "INJECTION"},
|
| 62 |
+
label2id: dict[str, int] = {"SAFE": 0, "INJECTION": 1},
|
| 63 |
learning_rate: float = 1e-5,
|
| 64 |
batch_size: int = 16,
|
| 65 |
num_epochs: int = 2,
|
|
|
|
| 67 |
save_steps: int = 1000,
|
| 68 |
streamlit_mode: bool = False,
|
| 69 |
):
|
| 70 |
+
"""
|
| 71 |
+
Trains a binary classifier using a specified dataset and model architecture.
|
| 72 |
+
|
| 73 |
+
This function sets up and trains a binary sequence classification model using
|
| 74 |
+
the Hugging Face Transformers library. It integrates with Weights & Biases for
|
| 75 |
+
experiment tracking and optionally displays a progress bar in a Streamlit app.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
project_name (str): The name of the Weights & Biases project.
|
| 79 |
+
entity_name (str): The Weights & Biases entity (user or team).
|
| 80 |
+
run_name (str): The name of the Weights & Biases run.
|
| 81 |
+
dataset_repo (str, optional): The Hugging Face dataset repository to load.
|
| 82 |
+
Defaults to "geekyrakshit/prompt-injection-dataset".
|
| 83 |
+
model_name (str, optional): The pre-trained model to use. Defaults to
|
| 84 |
+
"distilbert/distilbert-base-uncased".
|
| 85 |
+
prompt_column_name (str, optional): The column name in the dataset containing
|
| 86 |
+
the text prompts. Defaults to "prompt".
|
| 87 |
+
id2label (dict[int, str], optional): Mapping from label IDs to label names.
|
| 88 |
+
Defaults to {0: "SAFE", 1: "INJECTION"}.
|
| 89 |
+
label2id (dict[str, int], optional): Mapping from label names to label IDs.
|
| 90 |
+
Defaults to {"SAFE": 0, "INJECTION": 1}.
|
| 91 |
+
learning_rate (float, optional): The learning rate for training. Defaults to 1e-5.
|
| 92 |
+
batch_size (int, optional): The batch size for training and evaluation.
|
| 93 |
+
Defaults to 16.
|
| 94 |
+
num_epochs (int, optional): The number of training epochs. Defaults to 2.
|
| 95 |
+
weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.01.
|
| 96 |
+
save_steps (int, optional): The number of steps between model checkpoints.
|
| 97 |
+
Defaults to 1000.
|
| 98 |
+
streamlit_mode (bool, optional): If True, integrates with Streamlit to display
|
| 99 |
+
a progress bar. Defaults to False.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
dict: The output of the training process, including metrics and model state.
|
| 103 |
+
|
| 104 |
+
Raises:
|
| 105 |
+
Exception: If an error occurs during training, the exception is raised after
|
| 106 |
+
ensuring Weights & Biases run is finished.
|
| 107 |
+
"""
|
| 108 |
wandb.init(project=project_name, entity=entity_name, name=run_name)
|
| 109 |
if streamlit_mode:
|
| 110 |
st.markdown(
|
|
|
|
| 125 |
predictions = np.argmax(predictions, axis=1)
|
| 126 |
return accuracy.compute(predictions=predictions, references=labels)
|
| 127 |
|
|
|
|
|
|
|
|
|
|
| 128 |
model = AutoModelForSequenceClassification.from_pretrained(
|
| 129 |
model_name,
|
| 130 |
num_labels=2,
|
mkdocs.yml
CHANGED
|
@@ -68,6 +68,7 @@ nav:
|
|
| 68 |
- LLM: 'llm.md'
|
| 69 |
- Metrics: 'metrics.md'
|
| 70 |
- RegexModel: 'regex_model.md'
|
|
|
|
| 71 |
- Utils: 'utils.md'
|
| 72 |
|
| 73 |
repo_url: https://github.com/soumik12345/guardrails-genie
|
|
|
|
| 68 |
- LLM: 'llm.md'
|
| 69 |
- Metrics: 'metrics.md'
|
| 70 |
- RegexModel: 'regex_model.md'
|
| 71 |
+
- Train Classifier: 'train_classifier.md'
|
| 72 |
- Utils: 'utils.md'
|
| 73 |
|
| 74 |
repo_url: https://github.com/soumik12345/guardrails-genie
|