| """ | |
| Functions for fine-tuning Gemma models | |
| """ | |
| import os | |
| import time | |
| import json | |
| import threading | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| from datetime import datetime | |
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    DataCollatorForLanguageModeling
)
from peft import get_peft_model, LoraConfig, TaskType
from data_processing import create_train_val_split, format_for_training
from model_utils import load_model
from datasets import Dataset
# Global variable to store training state
_TRAINING_STATE = None

class TrainingThread(threading.Thread):
    """Thread class for running training in the background."""

    def __init__(self, model_name, dataset, params):
        threading.Thread.__init__(self)
        self.model_name = model_name
        self.dataset = dataset
        self.params = params
        self.stop_flag = False
        self.daemon = True  # Thread will exit when the main program exits

    def run(self):
        """Run the training process."""
        try:
            # Initialize training state
            global _TRAINING_STATE
            _TRAINING_STATE = {
                "status": "initializing",
                "current_epoch": 0,
                "current_step": 0,
                "total_steps": 0,
                "elapsed_time": 0,
                "loss_plot": None,
                "eval_plot": None,
                "log": "",
                "samples": None,
                "error": None
            }

            # Create output directory
            output_dir = os.path.join("outputs", datetime.now().strftime("%Y%m%d_%H%M%S"))
            os.makedirs(output_dir, exist_ok=True)

            # Load the model and tokenizer
            model, tokenizer = load_model(self.model_name)

            # Apply LoRA configuration
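            # r is the adapter rank and lora_alpha the scaling factor; the
            # defaults below (r=16, alpha=32, i.e. a 2x scale) are common
            # starting points, not values prescribed by PEFT.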
            lora_config = LoraConfig(
                r=self.params.get("lora_r", 16),
                lora_alpha=self.params.get("lora_alpha", 32),
                lora_dropout=self.params.get("lora_dropout", 0.05),
                bias="none",
                task_type=TaskType.CAUSAL_LM
            )
            model = get_peft_model(model, lora_config)

            # Split dataset into train and validation
            train_data, val_data = create_train_val_split(self.dataset)

            # Format data for training
            max_length = self.params.get("max_seq_length", 512)
            train_formatted = format_for_training(train_data, tokenizer, max_length)
            val_formatted = format_for_training(val_data, tokenizer, max_length)

            # Convert to HF Datasets
            train_dataset = Dataset.from_dict(train_formatted)
            val_dataset = Dataset.from_dict(val_formatted)

            # Create data collator
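            # mlm=False makes the collator produce causal-LM labels (a copy of
            # input_ids) rather than masked-LM targets.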
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=tokenizer,
                mlm=False
            )

            # Set up training arguments
            batch_size = self.params.get("batch_size", 4)
            gradient_accumulation_steps = self.params.get("gradient_accumulation_steps", 1)
            num_epochs = self.params.get("num_epochs", 3)

            # Calculate total steps
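            # Rough estimate for the progress display: assumes a single device
            # and ignores any partial final batch.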
            train_steps = max(1, len(train_dataset) // batch_size // gradient_accumulation_steps) * num_epochs
            _TRAINING_STATE["total_steps"] = train_steps

            # Training arguments
            training_args = TrainingArguments(
                output_dir=output_dir,
                learning_rate=self.params.get("learning_rate", 2e-5),
                per_device_train_batch_size=batch_size,
                per_device_eval_batch_size=batch_size,
                gradient_accumulation_steps=gradient_accumulation_steps,
                num_train_epochs=num_epochs,
                weight_decay=self.params.get("weight_decay", 0.01),
                warmup_steps=self.params.get("warmup_steps", 100),
                logging_dir=os.path.join(output_dir, "logs"),
                logging_steps=10,
                evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
                save_strategy="epoch",
                save_total_limit=2,
                load_best_model_at_end=True,
                report_to="none"  # Disable wandb, tensorboard, etc.
            )

            # Custom callback for UI updates; it must subclass TrainerCallback
            # so the Trainer's callback handler can dispatch events to it
            class UICallback(TrainerCallback):
                def __init__(self, thread):
                    self.thread = thread
                    self.start_time = time.time()
                    self.losses = []
                    self.eval_metrics = []
                    self.log_buffer = ""
                def on_log(self, args, state, control, logs=None, **kwargs):
                    if self.thread.stop_flag:
                        control.should_training_stop = True
                        _TRAINING_STATE["status"] = "stopped"
                        return
                    if logs is None:
                        return

                    # Update training state
                    _TRAINING_STATE["elapsed_time"] = time.time() - self.start_time

                    # Handle training logs
                    if "loss" in logs:
                        _TRAINING_STATE["current_step"] = state.global_step
                        loss = logs["loss"]
                        self.losses.append((state.global_step, loss))

                        # Update loss plot
                        fig, ax = plt.subplots(figsize=(10, 6))
                        steps, losses = zip(*self.losses)
                        ax.plot(steps, losses)
                        ax.set_xlabel("Steps")
                        ax.set_ylabel("Loss")
                        ax.set_title("Training Loss")
                        ax.grid(True)
                        _TRAINING_STATE["loss_plot"] = fig
                        plt.close(fig)  # Detach from pyplot so figures don't accumulate across updates

                        # Update log
                        log_entry = f"Step {state.global_step}: loss={loss:.4f}\n"
                        self.log_buffer += log_entry
                        _TRAINING_STATE["log"] = self.log_buffer

                    # Handle evaluation logs
                    if "eval_loss" in logs:
                        _TRAINING_STATE["current_epoch"] = state.epoch
                        eval_loss = logs["eval_loss"]
                        self.eval_metrics.append((state.epoch, eval_loss))

                        # Update eval plot
                        fig, ax = plt.subplots(figsize=(10, 6))
                        epochs, metrics = zip(*self.eval_metrics)
                        ax.plot(epochs, metrics)
                        ax.set_xlabel("Epochs")
                        ax.set_ylabel("Evaluation Loss")
                        ax.set_title("Validation Loss")
                        ax.grid(True)
                        _TRAINING_STATE["eval_plot"] = fig
                        plt.close(fig)

                        # Generate sample outputs for visualization
                        # (model and tokenizer are closed over from run()'s scope)
                        sample_outputs = self.generate_samples(model, tokenizer)
                        _TRAINING_STATE["samples"] = sample_outputs

                        # Update log
                        log_entry = f"Epoch {state.epoch}: eval_loss={eval_loss:.4f}\n"
                        self.log_buffer += log_entry
                        _TRAINING_STATE["log"] = self.log_buffer

                def generate_samples(self, model, tokenizer, num_samples=3):
                    """Generate sample outputs from the current model."""
                    # Get random samples from the validation set (closed over from run())
                    val_indices = np.random.choice(len(val_data), min(num_samples, len(val_data)), replace=False)
                    samples = [val_data[i] for i in val_indices]

                    results = []
                    model.eval()  # Disable dropout while sampling
                    for sample in samples:
                        prompt = sample["prompt"]
                        reference = sample["completion"]

                        # Generate text; do_sample=True is required for temperature to take effect
                        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
                        with torch.no_grad():
                            outputs = model.generate(
                                **inputs,
                                max_new_tokens=100,
                                do_sample=True,
                                temperature=0.7,
                                num_return_sequences=1
                            )
                        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)

                        # Remove the prompt from the generated text
                        if generated.startswith(prompt):
                            generated = generated[len(prompt):].strip()

                        results.append({
                            "Prompt": prompt,
                            "Generated Text": generated,
                            "Reference": reference
                        })
                    model.train()  # Restore training mode
                    return pd.DataFrame(results)

            # Create trainer
            ui_callback = UICallback(self)
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                data_collator=data_collator,
                callbacks=[ui_callback]
            )

            # Update training state
            _TRAINING_STATE["status"] = "training"

            # Start training
            trainer.train()

            # Save final model
            trainer.save_model(os.path.join(output_dir, "final"))
            tokenizer.save_pretrained(os.path.join(output_dir, "final"))

            # Update training state
            _TRAINING_STATE["status"] = "completed"
            _TRAINING_STATE["fine_tuned_model_path"] = os.path.join(output_dir, "final")

        except Exception as e:
            # Update training state with error
            _TRAINING_STATE["status"] = "error"
            _TRAINING_STATE["error"] = str(e)
            print(f"Training error: {str(e)}")

    def stop(self):
        """Signal the thread to stop training."""
        self.stop_flag = True

def start_fine_tuning(model_name, dataset, params):
    """
    Start the fine-tuning process in a background thread.

    Args:
        model_name: Name of the model to fine-tune
        dataset: Processed dataset
        params: Training parameters

    Returns:
        TrainingThread object
    """
    thread = TrainingThread(model_name, dataset, params)
    thread.start()
    return thread

def load_training_state():
    """
    Get the current training state.

    Returns:
        Dictionary with training state information
    """
    return _TRAINING_STATE
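
# Usage sketch (assumptions: the dataset is a list of dicts with "prompt" and
# "completion" keys, as generate_samples expects, and load_model accepts a HF
# model id; the model id, hyperparameters, and toy dataset below are
# illustrative only, and a real run needs far more examples).
if __name__ == "__main__":
    demo_dataset = [
        {"prompt": "Translate to French: Hello", "completion": "Bonjour"},
        {"prompt": "Translate to French: Thank you", "completion": "Merci"},
    ]
    t = start_fine_tuning(
        "google/gemma-2b",  # hypothetical model id; any checkpoint load_model supports
        demo_dataset,
        {"num_epochs": 1, "batch_size": 2, "learning_rate": 2e-5},
    )
    # Poll the shared state the way a UI would; call t.stop() to cancel early.
    while t.is_alive():
        state = load_training_state()
        if state is not None:
            print(state["status"], f"step {state['current_step']}/{state['total_steps']}")
        time.sleep(5)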