Commit 177344c · Parent: 883a576
update: LlamaGuardFineTuner + corresponding docs
Files changed:
- docs/train/train_llama_guard.md (+3, -0)
- guardrails_genie/train/__init__.py (+4, -0)
- guardrails_genie/train/llama_guard.py (+38, -5)
- mkdocs.yml (+1, -0)
docs/train/train_llama_guard.md
ADDED
@@ -0,0 +1,3 @@
+# Train Llama Guard
+
+::: guardrails_genie.train.llama_guard
guardrails_genie/train/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .train_classifier import train_binary_classifier
+from .llama_guard import LlamaGuardFineTuner, DatasetArgs
+
+__all__ = ["train_binary_classifier", "LlamaGuardFineTuner", "DatasetArgs"]
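With these re-exports in place, downstream code can pull both entry points from the package root instead of the submodules. A minimal sketch; the argument values are copied from the sample usage added later in this commit:

```python
# Both names resolve via the package-level re-exports above:
from guardrails_genie.train import DatasetArgs, LlamaGuardFineTuner

fine_tuner = LlamaGuardFineTuner(
    wandb_project="guardrails-genie",
    wandb_entity="geekyrakshit",
    streamlit_mode=False,
)
fine_tuner.load_dataset(
    DatasetArgs(
        dataset_address="wandb/synthetic-prompt-injections",
        train_dataset_range=-1,
        test_dataset_range=-1,
    )
)
```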
guardrails_genie/train/llama_guard.py
CHANGED
@@ -30,6 +30,26 @@ class LlamaGuardFineTuner:
     classification tasks, specifically for detecting prompt injection attacks. It
     integrates with Weights & Biases for experiment tracking and optionally
     displays progress in a Streamlit app.
+
+    !!! example "Sample Usage"
+        ```python
+        from guardrails_genie.train.llama_guard import LlamaGuardFineTuner, DatasetArgs
+
+        fine_tuner = LlamaGuardFineTuner(
+            wandb_project="guardrails-genie",
+            wandb_entity="geekyrakshit",
+            streamlit_mode=False,
+        )
+        fine_tuner.load_dataset(
+            DatasetArgs(
+                dataset_address="wandb/synthetic-prompt-injections",
+                train_dataset_range=-1,
+                test_dataset_range=-1,
+            )
+        )
+        fine_tuner.load_model()
+        fine_tuner.train(save_interval=100)
+        ```

     Args:
         wandb_project (str): The name of the Weights & Biases project.
@@ -63,6 +83,7 @@ class LlamaGuardFineTuner:
             train_dataset: The selected training dataset.
             test_dataset: The selected testing dataset.
         """
+        self.dataset_args = dataset_args
         dataset = load_dataset(dataset_args.dataset_address)
         self.train_dataset = (
             dataset["train"]
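Storing `dataset_args` on the instance is what lets `train()` later serialize it with `model_dump()`, which implies `DatasetArgs` is a Pydantic v2 model. A plausible sketch of its shape; the field types and defaults are assumptions inferred from the sample usage, not the actual definition:

```python
from pydantic import BaseModel

# Hypothetical reconstruction of DatasetArgs: the field names come from the
# sample usage in this commit; the types and defaults are guesses.
class DatasetArgs(BaseModel):
    dataset_address: str           # HF dataset path, e.g. "wandb/synthetic-prompt-injections"
    train_dataset_range: int = -1  # -1 plausibly means "use the full train split"
    test_dataset_range: int = -1   # -1 plausibly means "use the full test split"
```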
@@ -299,8 +320,8 @@ class LlamaGuardFineTuner:
         batch_size: int = 32,
         lr: float = 5e-6,
         num_classes: int = 2,
-        log_interval: int =
-        save_interval: int =
+        log_interval: int = 1,
+        save_interval: int = 50,
     ):
         """
         Fine-tunes the pre-trained LlamaGuard model on the training dataset for a single epoch.
@@ -332,13 +353,21 @@ class LlamaGuardFineTuner:
         wandb.init(
             project=self.wandb_project,
             entity=self.wandb_entity,
-            name=f"{self.model_name}-{self.
+            name=f"{self.model_name}-{self.dataset_args.dataset_address.split('/')[-1]}",
             job_type="fine-tune-llama-guard",
         )
+        wandb.config.dataset_args = self.dataset_args.model_dump()
+        wandb.config.model_name = self.model_name
+        wandb.config.batch_size = batch_size
+        wandb.config.lr = lr
+        wandb.config.num_classes = num_classes
+        wandb.config.log_interval = log_interval
+        wandb.config.save_interval = save_interval
         self.model.classifier = nn.Linear(
             self.model.classifier.in_features, num_classes
         )
         self.model.num_labels = num_classes
+        self.model = self.model.to(self.device)
         self.model.train()
         optimizer = optim.AdamW(self.model.parameters(), lr=lr)
         data_loader = DataLoader(
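Beyond logging the hyperparameters to `wandb.config`, this hunk re-initializes the classification head so the pre-trained checkpoint can be fine-tuned for a different label count. A minimal, self-contained illustration of that head-swap pattern; `TinyClassifier` is a stand-in, not the actual LlamaGuard architecture:

```python
import torch.nn as nn

class TinyClassifier(nn.Module):
    """Stand-in for a pre-trained sequence classifier."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 8)
        self.classifier = nn.Linear(8, 6)  # pre-trained head with 6 labels

    def forward(self, x):
        return self.classifier(self.encoder(x))

model = TinyClassifier()
num_classes = 2
# Replace the head with a freshly initialized layer of the right width;
# in_features is preserved so it still fits the encoder's output.
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
```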
@@ -367,8 +396,12 @@ class LlamaGuardFineTuner:
                     progress_percentage,
                     text=f"Training batch {i + 1}/{len(data_loader)}, Loss: {loss.item()}",
                 )
-            if (i + 1) % save_interval == 0:
+            if (i + 1) % save_interval == 0 or i + 1 == len(data_loader):
                 save_model(self.model, f"checkpoints/model-{i + 1}.safetensors")
-                wandb.log_model(
+                wandb.log_model(
+                    f"checkpoints/model-{i + 1}.safetensors",
+                    name=f"{wandb.run.id}-model",
+                    aliases=f"step-{i + 1}",
+                )
         wandb.finish()
         shutil.rmtree("checkpoints")
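Checkpoints saved with `save_model` are now uploaded as W&B model artifacts named `<run-id>-model` with a `step-<N>` alias per checkpoint, and the added `i + 1 == len(data_loader)` clause ensures the final batch is always checkpointed. A hedged sketch of fetching one back through the public artifact API; the entity/project come from the sample usage, `<run-id>` is a placeholder, and `step-100` assumes `save_interval=100`:

```python
import wandb

api = wandb.Api()
# log_model stores files as artifacts of type "model"; the name and alias
# below mirror the f-strings in the training loop.
artifact = api.artifact(
    "geekyrakshit/guardrails-genie/<run-id>-model:step-100", type="model"
)
local_dir = artifact.download()  # directory containing model-100.safetensors
```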
mkdocs.yml
CHANGED
@@ -80,6 +80,7 @@ nav:
     - RegexModel: 'regex_model.md'
     - Training:
       - Train Classifier: 'train/train_classifier.md'
+      - Train Llama Guard: 'train/train_llama_guard.md'
     - Utils: 'utils.md'

 repo_url: https://github.com/soumik12345/guardrails-genie