# Import environment setup before any other imports from env_setup import setup_environment setup_environment() import gradio as gr import os from model_utils import load_model, get_available_models from data_processing import process_dataset, validate_dataset from fine_tuning import start_fine_tuning, load_training_state import tempfile CSS = """ .feedback-div { padding: 10px; margin-bottom: 10px; border-radius: 5px; } .success { background-color: #d4edda; color: #155724; border: 1px solid #c3e6cb; } .error { background-color: #f8d7da; color: #721c24; border: 1px solid #f5c6cb; } .info { background-color: #d1ecf1; color: #0c5460; border: 1px solid #bee5eb; } """ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo: # Store state across tabs state = gr.State({ "dataset_path": None, "processed_dataset": None, "model_name": None, "model_instance": None, "training_params": None, "fine_tuned_model_path": None, "training_logs": [] }) with gr.Sidebar(): gr.Markdown("# Gemma Fine-Tuning UI") gr.Markdown("Sign in with your Hugging Face account to use the Nebius API for inference and model access.") button = gr.LoginButton("Sign in") gr.Markdown("## Navigation") with gr.Tab("Introduction"): gr.Markdown(""" # Welcome to Gemma Fine-Tuning UI This application allows you to fine-tune Google's Gemma models on your own datasets with a user-friendly interface. ## Features: - Upload and preprocess your datasets in various formats (CSV, JSONL, TXT) - Configure model hyperparameters for optimal performance - Visualize training progress in real-time - Export your fine-tuned model in different formats ## Getting Started: 1. Navigate to the **Dataset Upload** tab to prepare your data 2. Configure your model and hyperparameters in the **Model Configuration** tab 3. Start and monitor training in the **Training** tab 4. Export your fine-tuned model in the **Export Model** tab For more details, check the Documentation tab. """) with gr.Tab("Dataset Upload"): gr.Markdown("## Upload and prepare your dataset for fine-tuning") with gr.Row(): with gr.Column(): dataset_file = gr.File( label="Upload Dataset File (CSV, JSONL, or TXT)", file_types=["csv", "jsonl", "json", "txt"] ) data_format = gr.Radio( ["CSV", "JSONL", "Plain Text"], label="Data Format", value="CSV" ) with gr.Accordion("CSV Options", open=False): csv_prompt_col = gr.Textbox(label="Prompt Column Name", value="prompt") csv_completion_col = gr.Textbox(label="Completion Column Name", value="completion") csv_separator = gr.Textbox(label="Column Separator", value=",") with gr.Accordion("JSONL Options", open=False): jsonl_prompt_key = gr.Textbox(label="Prompt Key", value="prompt") jsonl_completion_key = gr.Textbox(label="Completion Key", value="completion") with gr.Accordion("Text Options", open=False): text_separator = gr.Textbox( label="Prompt/Completion Separator", value="###", info="Symbol or text that separates prompts from completions" ) process_btn = gr.Button("Process Dataset", variant="primary") with gr.Column(): dataset_info = gr.JSON(label="Dataset Information", visible=True) preview_df = gr.Dataframe(label="Data Preview", wrap=True) dataset_feedback = gr.Markdown( "", elem_classes=["feedback-div"] ) def process_dataset_handler( file, data_format, csv_prompt, csv_completion, csv_sep, jsonl_prompt, jsonl_completion, text_sep, current_state ): if file is None: return ( current_state, None, gr.update(value="⚠️ Please upload a file first", elem_classes=["feedback-div", "error"]), None ) try: # Create a temporary file to store the uploaded content temp_dir = tempfile.mkdtemp() file_path = os.path.join(temp_dir, file.name) # Save the uploaded file to the temporary location with open(file_path, "wb") as f: f.write(file.read()) # Prepare format-specific options options = { "format": data_format.lower(), "csv_prompt_col": csv_prompt, "csv_completion_col": csv_completion, "csv_separator": csv_sep, "jsonl_prompt_key": jsonl_prompt, "jsonl_completion_key": jsonl_completion, "text_separator": text_sep } # Validate the dataset is_valid, message = validate_dataset(file_path, options) if not is_valid: return ( current_state, None, gr.update(value=f"⚠️ {message}", elem_classes=["feedback-div", "error"]), None ) # Process the dataset processed_data, stats, preview = process_dataset(file_path, options) # Update state current_state = current_state.copy() current_state["dataset_path"] = file_path current_state["processed_dataset"] = processed_data return ( current_state, stats, gr.update(value="✅ Dataset processed successfully", elem_classes=["feedback-div", "success"]), preview ) except Exception as e: return ( current_state, None, gr.update(value=f"⚠️ Error processing dataset: {str(e)}", elem_classes=["feedback-div", "error"]), None ) process_btn.click( process_dataset_handler, inputs=[ dataset_file, data_format, csv_prompt_col, csv_completion_col, csv_separator, jsonl_prompt_key, jsonl_completion_key, text_separator, state ], outputs=[state, dataset_info, dataset_feedback, preview_df] ) with gr.Tab("Model Configuration"): gr.Markdown("## Select a model and configure hyperparameters") with gr.Row(): with gr.Column(): model_name = gr.Dropdown( choices=get_available_models(), label="Select Base Model", value="google/gemma-2-2b-it" ) with gr.Accordion("Training Parameters", open=True): learning_rate = gr.Slider( minimum=1e-6, maximum=1e-3, value=2e-5, step=1e-6, label="Learning Rate", info="Controls how quickly the model adapts to the training data" ) batch_size = gr.Slider( minimum=1, maximum=32, value=4, step=1, label="Batch Size", info="Number of samples processed before model weights are updated" ) num_epochs = gr.Slider( minimum=1, maximum=10, value=3, step=1, label="Number of Epochs", info="Number of complete passes through the training dataset" ) max_seq_length = gr.Slider( minimum=128, maximum=2048, value=512, step=64, label="Max Sequence Length", info="Maximum length of input sequences" ) with gr.Accordion("Advanced Options", open=False): gradient_accumulation_steps = gr.Slider( minimum=1, maximum=16, value=1, step=1, label="Gradient Accumulation Steps", info="Accumulate gradients over multiple batches to simulate larger batch size" ) warmup_steps = gr.Slider( minimum=0, maximum=500, value=100, step=10, label="Warmup Steps", info="Number of steps for learning rate warmup" ) weight_decay = gr.Slider( minimum=0, maximum=0.1, value=0.01, step=0.001, label="Weight Decay", info="L2 regularization factor to prevent overfitting" ) lora_r = gr.Slider( minimum=1, maximum=64, value=16, step=1, label="LoRA Rank (r)", info="Rank of LoRA adaptors (lower value = smaller model)" ) lora_alpha = gr.Slider( minimum=1, maximum=64, value=32, step=1, label="LoRA Alpha", info="LoRA scaling factor (higher = stronger adaptation)" ) lora_dropout = gr.Slider( minimum=0, maximum=0.5, value=0.05, step=0.01, label="LoRA Dropout", info="Dropout probability for LoRA layers" ) save_config_btn = gr.Button("Save Configuration", variant="primary") with gr.Column(): config_info = gr.JSON(label="Current Configuration") config_feedback = gr.Markdown( "", elem_classes=["feedback-div"] ) def save_config_handler( model, lr, bs, epochs, seq_len, grad_accum, warmup, weight_decay, lora_r, lora_alpha, lora_dropout, current_state ): # Check if dataset is processed if current_state["processed_dataset"] is None: return ( current_state, None, gr.update(value="⚠️ Please process a dataset first in the Dataset Upload tab", elem_classes=["feedback-div", "error"]) ) config = { "model_name": model, "learning_rate": lr, "batch_size": bs, "num_epochs": epochs, "max_seq_length": seq_len, "gradient_accumulation_steps": grad_accum, "warmup_steps": warmup, "weight_decay": weight_decay, "lora_r": lora_r, "lora_alpha": lora_alpha, "lora_dropout": lora_dropout } # Update state current_state = current_state.copy() current_state["model_name"] = model current_state["training_params"] = config return ( current_state, config, gr.update(value="✅ Configuration saved successfully", elem_classes=["feedback-div", "success"]) ) save_config_btn.click( save_config_handler, inputs=[ model_name, learning_rate, batch_size, num_epochs, max_seq_length, gradient_accumulation_steps, warmup_steps, weight_decay, lora_r, lora_alpha, lora_dropout, state ], outputs=[state, config_info, config_feedback] ) with gr.Tab("Training"): gr.Markdown("## Train your model and monitor progress") with gr.Row(): with gr.Column(scale=1): start_btn = gr.Button("Start Training", variant="primary", interactive=True) stop_btn = gr.Button("Stop Training", variant="stop", interactive=False) with gr.Accordion("Training Status", open=True): status = gr.Markdown("Not started", elem_classes=["feedback-div", "info"]) progress = gr.Slider( minimum=0, maximum=100, value=0, label="Training Progress", interactive=False ) current_epoch = gr.Number(label="Current Epoch", value=0, interactive=False) current_step = gr.Number(label="Current Step", value=0, interactive=False) elapsed_time = gr.Textbox(label="Elapsed Time", value="00:00:00", interactive=False) with gr.Column(scale=2): with gr.Row(): with gr.Column(): loss_plot = gr.Plot(label="Training Loss") with gr.Column(): eval_plot = gr.Plot(label="Evaluation Metrics") training_log = gr.Textbox( label="Training Log", interactive=False, lines=10 ) with gr.Accordion("Sample Generations", open=True): sample_outputs = gr.Dataframe( headers=["Prompt", "Generated Text", "Reference"], label="Sample Model Outputs", wrap=True ) # Timer for UI updates ui_update_interval = gr.Number(value=1, visible=False) def start_training_handler(current_state): # Validate state if current_state["processed_dataset"] is None: return ( current_state, gr.update(value="⚠️ Please process a dataset first", elem_classes=["feedback-div", "error"]), gr.update(interactive=True), gr.update(interactive=False) ) if current_state["training_params"] is None: return ( current_state, gr.update(value="⚠️ Please configure training parameters first", elem_classes=["feedback-div", "error"]), gr.update(interactive=True), gr.update(interactive=False) ) # Start training in a background thread try: train_thread = start_fine_tuning( model_name=current_state["model_name"], dataset=current_state["processed_dataset"], params=current_state["training_params"] ) current_state = current_state.copy() current_state["training_thread"] = train_thread return ( current_state, gr.update(value="✅ Training started", elem_classes=["feedback-div", "success"]), gr.update(interactive=False), gr.update(interactive=True) ) except Exception as e: return ( current_state, gr.update(value=f"⚠️ Error starting training: {str(e)}", elem_classes=["feedback-div", "error"]), gr.update(interactive=True), gr.update(interactive=False) ) def stop_training_handler(current_state): if "training_thread" in current_state and current_state["training_thread"] is not None: # Signal the training thread to stop current_state["training_thread"].stop() current_state = current_state.copy() current_state["training_thread"] = None return ( current_state, gr.update(value="⚠️ Training stopped by user", elem_classes=["feedback-div", "error"]), gr.update(interactive=True), gr.update(interactive=False) ) else: return ( current_state, gr.update(value="⚠️ No active training to stop", elem_classes=["feedback-div", "error"]), gr.update(interactive=True), gr.update(interactive=False) ) def update_training_ui(): training_state = load_training_state() if training_state is None: return ( 0, 0, 0, "00:00:00", None, None, "", None, gr.update(value="Not started", elem_classes=["feedback-div", "info"]) ) # Calculate progress percentage total_steps = training_state["total_steps"] current_step = training_state["current_step"] progress_pct = (current_step / total_steps * 100) if total_steps > 0 else 0 # Format elapsed time hours, remainder = divmod(training_state["elapsed_time"], 3600) minutes, seconds = divmod(remainder, 60) time_str = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}" # Update status message if training_state["status"] == "completed": status_msg = gr.update(value="✅ Training completed successfully", elem_classes=["feedback-div", "success"]) elif training_state["status"] == "error": status_msg = gr.update(value=f"⚠️ Training error: {training_state['error']}", elem_classes=["feedback-div", "error"]) elif training_state["status"] == "stopped": status_msg = gr.update(value="⚠️ Training stopped by user", elem_classes=["feedback-div", "error"]) else: status_msg = gr.update(value="⏳ Training in progress...", elem_classes=["feedback-div", "info"]) return ( progress_pct, training_state["current_epoch"], current_step, time_str, training_state["loss_plot"], training_state["eval_plot"], training_state["log"], training_state["samples"], status_msg ) start_btn.click( start_training_handler, inputs=[state], outputs=[state, status, start_btn, stop_btn] ) stop_btn.click( stop_training_handler, inputs=[state], outputs=[state, status, start_btn, stop_btn] ) # Remove problematic JavaScript loading approach # Create a simple manual refresh button for compatibility manual_refresh = gr.Button("Refresh Status", visible=True) manual_refresh.click( update_training_ui, inputs=None, outputs=[ progress, current_epoch, current_step, elapsed_time, loss_plot, eval_plot, training_log, sample_outputs, status ] ) # Add auto-refresh functionality with HTML component auto_refresh = gr.HTML("""

Auto-refreshing status every 2 seconds

""") # Initial UI update demo.load( update_training_ui, inputs=None, outputs=[ progress, current_epoch, current_step, elapsed_time, loss_plot, eval_plot, training_log, sample_outputs, status ] ) with gr.Tab("Export Model"): gr.Markdown("## Export your fine-tuned model") with gr.Row(): with gr.Column(): export_format = gr.Radio( ["PyTorch", "GGUF", "Safetensors"], label="Export Format", value="PyTorch" ) quantization = gr.Dropdown( ["None", "int8", "int4"], label="Quantization (GGUF only)", value="None", interactive=True ) model_name_input = gr.Textbox( label="Model Name", placeholder="my-fine-tuned-gemma", value="my-fine-tuned-gemma" ) output_dir = gr.Textbox( label="Output Directory", placeholder="Path to save the exported model", value="./exports" ) export_btn = gr.Button("Export Model", variant="primary") with gr.Column(): export_info = gr.JSON(label="Export Information", visible=False) export_status = gr.Markdown( "", elem_classes=["feedback-div"] ) # Fix: Remove 'visible' parameter which is not supported in this Gradio version export_progress = gr.Progress() def export_model_handler(current_state, format, quant, name, out_dir): if current_state.get("fine_tuned_model_path") is None: return ( gr.update(value="⚠️ No fine-tuned model available. Please complete training first.", elem_classes=["feedback-div", "error"]), None ) try: # Actual export would be implemented in another function export_path = os.path.join(out_dir, name) os.makedirs(export_path, exist_ok=True) export_info = { "format": format, "quantization": quant if format == "GGUF" else "None", "model_name": name, "export_path": export_path, "model_size": "0.5 GB", # This would be calculated during actual export "export_time": "00:01:23" # This would be measured during actual export } return ( gr.update(value=f"✅ Model exported successfully to {export_path}", elem_classes=["feedback-div", "success"]), export_info ) except Exception as e: return ( gr.update(value=f"⚠️ Error exporting model: {str(e)}", elem_classes=["feedback-div", "error"]), None ) export_btn.click( export_model_handler, inputs=[state, export_format, quantization, model_name_input, output_dir], # Update outputs list to remove reference to progress visibility outputs=[export_status, export_info] ) with gr.Tab("Documentation"): gr.Markdown(""" # Gemma Fine-Tuning Documentation ## Supported Models This application supports fine-tuning the following Gemma models: - google/gemma-2-2b-it - google/gemma-2-9b-it - google/gemma-2-27b-it ## Dataset Format Your dataset should follow one of these formats: ### CSV ``` prompt,completion "What is the capital of France?","The capital of France is Paris." "How does photosynthesis work?","Photosynthesis is the process..." ``` ### JSONL ``` {"prompt": "What is the capital of France?", "completion": "The capital of France is Paris."} {"prompt": "How does photosynthesis work?", "completion": "Photosynthesis is the process..."} ``` ### Plain Text ``` What is the capital of France? ### The capital of France is Paris. ### How does photosynthesis work? ### Photosynthesis is the process... ``` ## Fine-Tuning Parameters ### Basic Parameters - **Learning Rate**: Controls how quickly the model adapts to the training data. Typical values range from 1e-5 to 5e-5. - **Batch Size**: Number of samples processed before model weights are updated. Higher values require more memory. - **Number of Epochs**: Number of complete passes through the training dataset. More epochs can lead to better results but may cause overfitting. - **Max Sequence Length**: Maximum length of input sequences. Longer sequences require more memory. ### Advanced Parameters - **Gradient Accumulation Steps**: Accumulate gradients over multiple batches to simulate larger batch size. - **Warmup Steps**: Number of steps for learning rate warmup. Helps stabilize training in the early phases. - **Weight Decay**: L2 regularization factor to prevent overfitting. - **LoRA Parameters**: Controls the behavior of LoRA (Low-Rank Adaptation), a parameter-efficient fine-tuning technique. ## Export Formats - **PyTorch**: Standard PyTorch model format (.pt or .bin files with model architecture). - **GGUF**: Compact format optimized for efficient inference (especially with llama.cpp). - **Safetensors**: Safe format for storing tensors, preventing arbitrary code execution. ## Quantization Quantization reduces model size and increases inference speed at the cost of some accuracy: - **None**: No quantization, full precision (usually FP16 or BF16). - **int8**: 8-bit integer quantization, good balance of speed and accuracy. - **int4**: 4-bit integer quantization, fastest but may reduce accuracy more significantly. """) demo.launch()