Spaces:
Runtime error
Runtime error
Commit
·
5706850
1
Parent(s):
780c9f0
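update: LlamaGuardFineTuner.train — add configurable training arguments (batch_size, lr, num_classes, log_interval, save_interval), log the loss to W&B every log_interval batches, and save/upload a safetensors checkpoint every save_interval batches (the local checkpoints directory is removed after the run finishes)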
Browse files
guardrails_genie/train/llama_guard.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
|
| 3 |
import plotly.graph_objects as go
|
| 4 |
import streamlit as st
|
|
@@ -208,7 +209,15 @@ class LlamaGuardFineTuner:
|
|
| 208 |
)
|
| 209 |
return encodings.input_ids, encodings.attention_mask, labels
|
| 210 |
|
| 211 |
-
def train(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
wandb.init(
|
| 213 |
project=self.wandb_project,
|
| 214 |
entity=self.wandb_entity,
|
|
@@ -239,14 +248,16 @@ class LlamaGuardFineTuner:
|
|
| 239 |
optimizer.zero_grad()
|
| 240 |
loss.backward()
|
| 241 |
optimizer.step()
|
| 242 |
-
|
|
|
|
| 243 |
if progress_bar:
|
| 244 |
progress_percentage = (i + 1) * 100 // len(data_loader)
|
| 245 |
progress_bar.progress(
|
| 246 |
progress_percentage,
|
| 247 |
text=f"Training batch {i + 1}/{len(data_loader)}, Loss: {loss.item()}",
|
| 248 |
)
|
| 249 |
-
|
| 250 |
-
|
|
|
|
| 251 |
wandb.finish()
|
| 252 |
-
|
|
|
|
| 1 |
import os
|
| 2 |
+
import shutil
|
| 3 |
|
| 4 |
import plotly.graph_objects as go
|
| 5 |
import streamlit as st
|
|
|
|
| 209 |
)
|
| 210 |
return encodings.input_ids, encodings.attention_mask, labels
|
| 211 |
|
| 212 |
+
def train(
|
| 213 |
+
self,
|
| 214 |
+
batch_size: int = 32,
|
| 215 |
+
lr: float = 5e-6,
|
| 216 |
+
num_classes: int = 2,
|
| 217 |
+
log_interval: int = 20,
|
| 218 |
+
save_interval: int = 1000,
|
| 219 |
+
):
|
| 220 |
+
os.makedirs("checkpoints", exist_ok=True)
|
| 221 |
wandb.init(
|
| 222 |
project=self.wandb_project,
|
| 223 |
entity=self.wandb_entity,
|
|
|
|
| 248 |
optimizer.zero_grad()
|
| 249 |
loss.backward()
|
| 250 |
optimizer.step()
|
| 251 |
+
if (i + 1) % log_interval == 0:
|
| 252 |
+
wandb.log({"loss": loss.item()}, step=i + 1)
|
| 253 |
if progress_bar:
|
| 254 |
progress_percentage = (i + 1) * 100 // len(data_loader)
|
| 255 |
progress_bar.progress(
|
| 256 |
progress_percentage,
|
| 257 |
text=f"Training batch {i + 1}/{len(data_loader)}, Loss: {loss.item()}",
|
| 258 |
)
|
| 259 |
+
if (i + 1) % save_interval == 0:
|
| 260 |
+
save_model(self.model, f"checkpoints/model-{i + 1}.safetensors")
|
| 261 |
+
wandb.log_model(f"checkpoints/model-{i + 1}.safetensors")
|
| 262 |
wandb.finish()
|
| 263 |
+
shutil.rmtree("checkpoints")
|