""" Functions for fine-tuning Gemma models """ import os import time import json import threading import torch import numpy as np import matplotlib.pyplot as plt import pandas as pd from datetime import datetime from transformers import ( AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling ) from peft import get_peft_model, LoraConfig, TaskType from data_processing import create_train_val_split, format_for_training from model_utils import load_model from datasets import Dataset # Global variable to store training state _TRAINING_STATE = None class TrainingThread(threading.Thread): """Thread class for running training in the background.""" def __init__(self, model_name, dataset, params): threading.Thread.__init__(self) self.model_name = model_name self.dataset = dataset self.params = params self.stop_flag = False self.daemon = True # Thread will exit when main program exits def run(self): """Run the training process.""" try: # Initialize training state global _TRAINING_STATE _TRAINING_STATE = { "status": "initializing", "current_epoch": 0, "current_step": 0, "total_steps": 0, "elapsed_time": 0, "loss_plot": None, "eval_plot": None, "log": "", "samples": None, "error": None } # Create output directory output_dir = os.path.join("outputs", datetime.now().strftime("%Y%m%d_%H%M%S")) os.makedirs(output_dir, exist_ok=True) # Load the model and tokenizer model, tokenizer = load_model(self.model_name) # Apply LoRA configuration lora_config = LoraConfig( r=self.params.get("lora_r", 16), lora_alpha=self.params.get("lora_alpha", 32), lora_dropout=self.params.get("lora_dropout", 0.05), bias="none", task_type=TaskType.CAUSAL_LM ) model = get_peft_model(model, lora_config) # Split dataset into train and validation train_data, val_data = create_train_val_split(self.dataset) # Format data for training max_length = self.params.get("max_seq_length", 512) train_formatted = format_for_training(train_data, tokenizer, max_length) val_formatted = format_for_training(val_data, tokenizer, max_length) # Convert to HF Datasets train_dataset = Dataset.from_dict(train_formatted) val_dataset = Dataset.from_dict(val_formatted) # Create data collator data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=False ) # Set up training arguments batch_size = self.params.get("batch_size", 4) gradient_accumulation_steps = self.params.get("gradient_accumulation_steps", 1) num_epochs = self.params.get("num_epochs", 3) # Calculate total steps train_steps = len(train_dataset) // batch_size // gradient_accumulation_steps * num_epochs _TRAINING_STATE["total_steps"] = train_steps # Training arguments training_args = TrainingArguments( output_dir=output_dir, learning_rate=self.params.get("learning_rate", 2e-5), per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps, num_train_epochs=num_epochs, weight_decay=self.params.get("weight_decay", 0.01), warmup_steps=self.params.get("warmup_steps", 100), logging_dir=os.path.join(output_dir, "logs"), logging_steps=10, evaluation_strategy="epoch", save_strategy="epoch", save_total_limit=2, load_best_model_at_end=True, report_to="none" # Disable wandb, tensorboard, etc. 
) # Custom callback for UI updates class UICallback: def __init__(self, thread): self.thread = thread self.start_time = time.time() self.losses = [] self.eval_metrics = [] self.log_buffer = "" def on_log(self, args, state, control, logs=None, **kwargs): if self.thread.stop_flag: control.should_training_stop = True _TRAINING_STATE["status"] = "stopped" return if logs is None: return # Update training state _TRAINING_STATE["elapsed_time"] = time.time() - self.start_time # Handle training logs if "loss" in logs: _TRAINING_STATE["current_step"] = state.global_step loss = logs["loss"] self.losses.append((state.global_step, loss)) # Update loss plot fig, ax = plt.subplots(figsize=(10, 6)) steps, losses = zip(*self.losses) ax.plot(steps, losses) ax.set_xlabel("Steps") ax.set_ylabel("Loss") ax.set_title("Training Loss") ax.grid(True) _TRAINING_STATE["loss_plot"] = fig # Update log log_entry = f"Step {state.global_step}: loss={loss:.4f}\n" self.log_buffer += log_entry _TRAINING_STATE["log"] = self.log_buffer # Handle evaluation logs if "eval_loss" in logs: _TRAINING_STATE["current_epoch"] = state.epoch eval_loss = logs["eval_loss"] self.eval_metrics.append((state.epoch, eval_loss)) # Update eval plot fig, ax = plt.subplots(figsize=(10, 6)) epochs, metrics = zip(*self.eval_metrics) ax.plot(epochs, metrics) ax.set_xlabel("Epochs") ax.set_ylabel("Evaluation Loss") ax.set_title("Validation Loss") ax.grid(True) _TRAINING_STATE["eval_plot"] = fig # Generate sample outputs for visualization sample_outputs = self.generate_samples(model, tokenizer) _TRAINING_STATE["samples"] = sample_outputs # Update log log_entry = f"Epoch {state.epoch}: eval_loss={eval_loss:.4f}\n" self.log_buffer += log_entry _TRAINING_STATE["log"] = self.log_buffer def generate_samples(self, model, tokenizer, num_samples=3): """Generate sample outputs from the current model.""" # Get random samples from validation set val_indices = np.random.choice(len(val_data), min(num_samples, len(val_data)), replace=False) samples = [val_data[i] for i in val_indices] results = [] for sample in samples: prompt = sample["prompt"] reference = sample["completion"] # Generate text inputs = tokenizer(prompt, return_tensors="pt").to(model.device) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=100, temperature=0.7, num_return_sequences=1 ) generated = tokenizer.decode(outputs[0], skip_special_tokens=True) # Remove the prompt from the generated text if generated.startswith(prompt): generated = generated[len(prompt):].strip() results.append({ "Prompt": prompt, "Generated Text": generated, "Reference": reference }) return pd.DataFrame(results) # Create trainer ui_callback = UICallback(self) trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, data_collator=data_collator, callbacks=[ui_callback] ) # Update training state _TRAINING_STATE["status"] = "training" # Start training trainer.train() # Save final model trainer.save_model(os.path.join(output_dir, "final")) tokenizer.save_pretrained(os.path.join(output_dir, "final")) # Update training state _TRAINING_STATE["status"] = "completed" _TRAINING_STATE["fine_tuned_model_path"] = os.path.join(output_dir, "final") except Exception as e: # Update training state with error _TRAINING_STATE["status"] = "error" _TRAINING_STATE["error"] = str(e) print(f"Training error: {str(e)}") def stop(self): """Signal the thread to stop training.""" self.stop_flag = True def start_fine_tuning(model_name, dataset, params): """ Start the fine-tuning 
def start_fine_tuning(model_name, dataset, params):
    """
    Start the fine-tuning process in a background thread.

    Args:
        model_name: Name of the model to fine-tune
        dataset: Processed dataset
        params: Training parameters

    Returns:
        TrainingThread object
    """
    thread = TrainingThread(model_name, dataset, params)
    thread.start()
    return thread


def load_training_state():
    """
    Get the current training state.

    Returns:
        Dictionary with training state information
    """
    return _TRAINING_STATE
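

# Example usage: a minimal sketch of driving this module directly, not part of its
# public API. The model name, dataset records, and params values below are illustrative
# assumptions; the record shape mirrors the "prompt"/"completion" fields read in
# generate_samples, and the params keys match the self.params.get(...) defaults in run().
if __name__ == "__main__":
    example_dataset = [
        {"prompt": "Translate to French: Hello", "completion": "Bonjour"},
        {"prompt": "Translate to French: Goodbye", "completion": "Au revoir"},
    ]
    example_params = {
        "num_epochs": 1,
        "batch_size": 2,
        "learning_rate": 2e-5,
        "lora_r": 16,
        "max_seq_length": 512,
    }

    # Kick off training in the background, then poll the shared state dict for progress
    training_thread = start_fine_tuning("google/gemma-2b", example_dataset, example_params)
    while training_thread.is_alive():
        state = load_training_state()
        if state is not None:
            print(f"status={state['status']} step={state['current_step']}/{state['total_steps']}")
        time.sleep(10)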