Commit 883a576 · Parent: 32d5d0c
add: docs for LlamaGuardFineTuner
guardrails_genie/train/llama_guard.py
CHANGED
@@ -24,6 +24,19 @@ class DatasetArgs(BaseModel):
 
 
 class LlamaGuardFineTuner:
+    """
+    `LlamaGuardFineTuner` is a class designed to fine-tune and evaluate the
+    [Prompt Guard model by Meta Llama](https://huggingface.co/meta-llama/Prompt-Guard-86M)
+    for prompt classification tasks, specifically for detecting prompt injection
+    attacks. It integrates with Weights & Biases for experiment tracking and
+    optionally displays progress in a Streamlit app.
+
+    Args:
+        wandb_project (str): The name of the Weights & Biases project.
+        wandb_entity (str): The Weights & Biases entity (user or team).
+        streamlit_mode (bool): If True, integrates with Streamlit to display progress.
+    """
+
     def __init__(
         self, wandb_project: str, wandb_entity: str, streamlit_mode: bool = False
     ):
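For orientation, a minimal usage sketch of the constructor documented above (the project and entity names are placeholders, not values from this commit):

    from guardrails_genie.train.llama_guard import LlamaGuardFineTuner

    fine_tuner = LlamaGuardFineTuner(
        wandb_project="guardrails-genie",  # placeholder W&B project name
        wandb_entity="my-team",            # placeholder W&B entity
        streamlit_mode=False,              # set True only inside a Streamlit app
    )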
@@ -32,6 +45,24 @@ class LlamaGuardFineTuner:
         self.streamlit_mode = streamlit_mode
 
     def load_dataset(self, dataset_args: DatasetArgs):
+        """
+        Loads the training and testing datasets based on the provided dataset arguments.
+
+        This function uses the `load_dataset` function from the `datasets` library to load
+        the dataset specified by the `dataset_address` attribute of the `dataset_args` parameter.
+        It then selects a subset of the training and testing datasets based on the specified
+        ranges in the `train_dataset_range` and `test_dataset_range` attributes of `dataset_args`.
+        If the specified range is less than or equal to 0 or exceeds the length of the dataset,
+        the entire dataset is used.
+
+        Args:
+            dataset_args (DatasetArgs): An instance of the `DatasetArgs` class containing
+                the dataset address and the ranges for the training and testing datasets.
+
+        Attributes:
+            train_dataset: The selected training dataset.
+            test_dataset: The selected testing dataset.
+        """
         dataset = load_dataset(dataset_args.dataset_address)
         self.train_dataset = (
             dataset["train"]
@@ -47,6 +78,22 @@ class LlamaGuardFineTuner:
         )
 
     def load_model(self, model_name: str = "meta-llama/Prompt-Guard-86M"):
+        """
+        Loads the specified pre-trained model and tokenizer for sequence classification tasks.
+
+        This function sets the device to GPU if available, otherwise defaults to CPU. It then
+        loads the tokenizer and model from the Hugging Face model hub using the provided model name.
+        The model is moved to the specified device (GPU or CPU).
+
+        Args:
+            model_name (str): The name of the pre-trained model to load.
+
+        Attributes:
+            device (str): The device to run the model on, either "cuda" for GPU or "cpu".
+            model_name (str): The name of the loaded pre-trained model.
+            tokenizer (AutoTokenizer): The tokenizer associated with the pre-trained model.
+            model (AutoModelForSequenceClassification): The loaded pre-trained model for sequence classification.
+        """
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model_name = model_name
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
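Loading weights is then a single call; a minimal sketch using the default checkpoint:

    fine_tuner.load_model()   # defaults to "meta-llama/Prompt-Guard-86M"
    print(fine_tuner.device)  # "cuda" when a GPU is visible, otherwise "cpu"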
@@ -55,6 +102,19 @@ class LlamaGuardFineTuner:
         )
 
     def show_dataset_sample(self):
+        """
+        Displays a sample of the training and testing datasets using Streamlit.
+
+        This function checks if the `streamlit_mode` attribute is enabled. If it is,
+        it converts the training and testing datasets to pandas DataFrames and displays
+        the first few rows of each dataset using Streamlit's `dataframe` function. The
+        training dataset sample is displayed under the heading "Train Dataset Sample",
+        and the testing dataset sample is displayed under the heading "Test Dataset Sample".
+
+        Note:
+            This function requires the `streamlit` library to be installed and the
+            `streamlit_mode` attribute to be set to True.
+        """
         if self.streamlit_mode:
             st.markdown("### Train Dataset Sample")
             st.dataframe(self.train_dataset.to_pandas().head())
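Because the method is a no-op unless `streamlit_mode` is enabled, the intended call site is a Streamlit script; a hypothetical wiring:

    # run with: streamlit run app.py
    fine_tuner = LlamaGuardFineTuner(
        wandb_project="guardrails-genie", wandb_entity="my-team", streamlit_mode=True
    )
    fine_tuner.load_dataset(dataset_args)
    fine_tuner.show_dataset_sample()  # renders the head() of both splits via st.dataframe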
@@ -189,6 +249,31 @@ class LlamaGuardFineTuner:
         truncation: bool = True,
         max_length: int = 512,
     ):
+        """
+        Evaluates the fine-tuned model on the test dataset and visualizes the results.
+
+        This function evaluates the model by processing the test dataset in batches.
+        It computes the test scores using the `evaluate_batch` method, which takes
+        several parameters to control the evaluation process, such as batch size,
+        positive label, temperature, truncation, and maximum sequence length.
+
+        After obtaining the test scores, it visualizes the performance of the model
+        using two methods:
+        1. `visualize_roc_curve`: Plots the Receiver Operating Characteristic (ROC) curve
+           to show the trade-off between the true positive rate and false positive rate.
+        2. `visualize_score_distribution`: Plots the distribution of scores for positive
+           and negative examples to provide insights into the model's performance.
+
+        Args:
+            batch_size (int, optional): The number of samples to process in each batch.
+            positive_label (int, optional): The label considered as positive for evaluation.
+            temperature (float, optional): The temperature parameter for scaling logits.
+            truncation (bool, optional): Whether to truncate sequences to the maximum length.
+            max_length (int, optional): The maximum length of sequences after truncation.
+
+        Returns:
+            list[float]: The test scores obtained from the evaluation.
+        """
         test_scores = self.evaluate_batch(
             self.test_dataset["text"],
             batch_size=batch_size,
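Per this docstring, evaluation both returns the scores and renders the two plots. A hedged sketch follows; the method's `def` line sits outside this hunk, so the name `evaluate_model` and every default except `truncation` and `max_length` are assumptions:

    test_scores = fine_tuner.evaluate_model(  # assumed method name
        batch_size=32,     # illustrative; default not shown in this diff
        positive_label=1,  # illustrative
        temperature=1.0,   # illustrative
        truncation=True,
        max_length=512,
    )
    # side effects: ROC curve and score-distribution plots; returns list[float]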
@@ -217,6 +302,32 @@ class LlamaGuardFineTuner:
         log_interval: int = 20,
         save_interval: int = 1000,
     ):
+        """
+        Fine-tunes the pre-trained LlamaGuard model on the training dataset for a single epoch.
+
+        This function sets up and executes the training loop for the LlamaGuard model.
+        It initializes the Weights & Biases (wandb) logging, configures the model's
+        classifier layer to match the specified number of classes, and sets the model
+        to training mode. The function uses an AdamW optimizer to update the model
+        parameters based on the computed loss.
+
+        The training process involves iterating over the training dataset in batches,
+        computing the loss for each batch, and updating the model parameters. The
+        function logs the loss to wandb at specified intervals and optionally displays
+        a progress bar using Streamlit if `streamlit_mode` is enabled. Model checkpoints
+        are saved at specified intervals during training.
+
+        Args:
+            batch_size (int, optional): The number of samples per batch during training.
+            lr (float, optional): The learning rate for the optimizer.
+            num_classes (int, optional): The number of output classes for the classifier.
+            log_interval (int, optional): The interval (in batches) at which to log the loss.
+            save_interval (int, optional): The interval (in batches) at which to save model checkpoints.
+
+        Note:
+            This function requires the `wandb` and `streamlit` libraries to be installed
+            and configured appropriately.
+        """
         os.makedirs("checkpoints", exist_ok=True)
         wandb.init(
             project=self.wandb_project,
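Taken together, the documented API suggests an end-to-end flow along these lines (the training method's name is assumed to be `train`, and all hyperparameter values except the `log_interval`/`save_interval` defaults are illustrative):

    fine_tuner.load_dataset(dataset_args)
    fine_tuner.load_model()
    fine_tuner.train(          # assumed method name
        batch_size=16,         # illustrative
        lr=5e-6,               # illustrative
        num_classes=2,         # e.g. benign vs. injection
        log_interval=20,       # default shown in this diff
        save_interval=1000,    # default shown in this diff
    )
    # checkpoints land under ./checkpoints; losses stream to the configured W&B project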