# Import environment setup before any other imports
from env_setup import setup_environment

setup_environment()

import os
import shutil
import tempfile

import gradio as gr

from model_utils import load_model, get_available_models
from data_processing import process_dataset, validate_dataset
from fine_tuning import start_fine_tuning, load_training_state
CSS = """
.feedback-div {
    padding: 10px;
    margin-bottom: 10px;
    border-radius: 5px;
}
.success {
    background-color: #d4edda;
    color: #155724;
    border: 1px solid #c3e6cb;
}
.error {
    background-color: #f8d7da;
    color: #721c24;
    border: 1px solid #f5c6cb;
}
.info {
    background-color: #d1ecf1;
    color: #0c5460;
    border: 1px solid #bee5eb;
}
"""
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    # Store state across tabs ("training_thread" is populated once training starts)
    state = gr.State({
        "dataset_path": None,
        "processed_dataset": None,
        "model_name": None,
        "model_instance": None,
        "training_params": None,
        "fine_tuned_model_path": None,
        "training_logs": [],
        "training_thread": None,
    })
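    # Note: gr.State is per-session, so each browser session gets its own copy
    # of this dict; nothing stored here is shared between users.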
    with gr.Sidebar():
        gr.Markdown("# Gemma Fine-Tuning UI")
        gr.Markdown("Sign in with your Hugging Face account to use the Nebius API for inference and model access.")
        button = gr.LoginButton("Sign in")
        gr.Markdown("## Navigation")
| with gr.Tab("Introduction"): | |
| gr.Markdown(""" | |
| # Welcome to Gemma Fine-Tuning UI | |
| This application allows you to fine-tune Google's Gemma models on your own datasets with a user-friendly interface. | |
| ## Features: | |
| - Upload and preprocess your datasets in various formats (CSV, JSONL, TXT) | |
| - Configure model hyperparameters for optimal performance | |
| - Visualize training progress in real-time | |
| - Export your fine-tuned model in different formats | |
| ## Getting Started: | |
| 1. Navigate to the **Dataset Upload** tab to prepare your data | |
| 2. Configure your model and hyperparameters in the **Model Configuration** tab | |
| 3. Start and monitor training in the **Training** tab | |
| 4. Export your fine-tuned model in the **Export Model** tab | |
| For more details, check the Documentation tab. | |
| """) | |
| with gr.Tab("Dataset Upload"): | |
| gr.Markdown("## Upload and prepare your dataset for fine-tuning") | |
| with gr.Row(): | |
| with gr.Column(): | |
| dataset_file = gr.File( | |
| label="Upload Dataset File (CSV, JSONL, or TXT)", | |
| file_types=["csv", "jsonl", "json", "txt"] | |
| ) | |
| data_format = gr.Radio( | |
| ["CSV", "JSONL", "Plain Text"], | |
| label="Data Format", | |
| value="CSV" | |
| ) | |
| with gr.Accordion("CSV Options", open=False): | |
| csv_prompt_col = gr.Textbox(label="Prompt Column Name", value="prompt") | |
| csv_completion_col = gr.Textbox(label="Completion Column Name", value="completion") | |
| csv_separator = gr.Textbox(label="Column Separator", value=",") | |
| with gr.Accordion("JSONL Options", open=False): | |
| jsonl_prompt_key = gr.Textbox(label="Prompt Key", value="prompt") | |
| jsonl_completion_key = gr.Textbox(label="Completion Key", value="completion") | |
| with gr.Accordion("Text Options", open=False): | |
| text_separator = gr.Textbox( | |
| label="Prompt/Completion Separator", | |
| value="###", | |
| info="Symbol or text that separates prompts from completions" | |
| ) | |
| process_btn = gr.Button("Process Dataset", variant="primary") | |
| with gr.Column(): | |
| dataset_info = gr.JSON(label="Dataset Information", visible=True) | |
| preview_df = gr.Dataframe(label="Data Preview", wrap=True) | |
| dataset_feedback = gr.Markdown( | |
| "", | |
| elem_classes=["feedback-div"] | |
| ) | |
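        # Assumed contracts for the data_processing helpers (that module is not
        # shown here): validate_dataset(path, options) -> (bool, message) and
        # process_dataset(path, options) -> (processed_data, stats_dict, preview_rows).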
        def process_dataset_handler(
            file, data_format, csv_prompt, csv_completion, csv_sep,
            jsonl_prompt, jsonl_completion, text_sep, current_state
        ):
            if file is None:
                return (
                    current_state,
                    None,
                    gr.update(value="⚠️ Please upload a file first", elem_classes=["feedback-div", "error"]),
                    None
                )
            try:
                # Copy the upload into a temporary working directory.
                # gr.File hands the handler a tempfile wrapper whose .name is the
                # path on disk, so copy by path instead of reading the object.
                temp_dir = tempfile.mkdtemp()
                file_path = os.path.join(temp_dir, os.path.basename(file.name))
                shutil.copy(file.name, file_path)
                # Prepare format-specific options
                options = {
                    "format": data_format.lower(),
                    "csv_prompt_col": csv_prompt,
                    "csv_completion_col": csv_completion,
                    "csv_separator": csv_sep,
                    "jsonl_prompt_key": jsonl_prompt,
                    "jsonl_completion_key": jsonl_completion,
                    "text_separator": text_sep
                }
                # Validate the dataset
                is_valid, message = validate_dataset(file_path, options)
                if not is_valid:
                    return (
                        current_state,
                        None,
                        gr.update(value=f"⚠️ {message}", elem_classes=["feedback-div", "error"]),
                        None
                    )
                # Process the dataset
                processed_data, stats, preview = process_dataset(file_path, options)
                # Update state (work on a copy rather than mutating the old value)
                current_state = current_state.copy()
                current_state["dataset_path"] = file_path
                current_state["processed_dataset"] = processed_data
                return (
                    current_state,
                    stats,
                    gr.update(value="✅ Dataset processed successfully", elem_classes=["feedback-div", "success"]),
                    preview
                )
            except Exception as e:
                return (
                    current_state,
                    None,
                    gr.update(value=f"⚠️ Error processing dataset: {str(e)}", elem_classes=["feedback-div", "error"]),
                    None
                )
        process_btn.click(
            process_dataset_handler,
            inputs=[
                dataset_file, data_format,
                csv_prompt_col, csv_completion_col, csv_separator,
                jsonl_prompt_key, jsonl_completion_key,
                text_separator, state
            ],
            outputs=[state, dataset_info, dataset_feedback, preview_df]
        )
| with gr.Tab("Model Configuration"): | |
| gr.Markdown("## Select a model and configure hyperparameters") | |
| with gr.Row(): | |
| with gr.Column(): | |
| model_name = gr.Dropdown( | |
| choices=get_available_models(), | |
| label="Select Base Model", | |
| value="google/gemma-2-2b-it" | |
| ) | |
| with gr.Accordion("Training Parameters", open=True): | |
| learning_rate = gr.Slider( | |
| minimum=1e-6, maximum=1e-3, value=2e-5, step=1e-6, | |
| label="Learning Rate", | |
| info="Controls how quickly the model adapts to the training data" | |
| ) | |
| batch_size = gr.Slider( | |
| minimum=1, maximum=32, value=4, step=1, | |
| label="Batch Size", | |
| info="Number of samples processed before model weights are updated" | |
| ) | |
| num_epochs = gr.Slider( | |
| minimum=1, maximum=10, value=3, step=1, | |
| label="Number of Epochs", | |
| info="Number of complete passes through the training dataset" | |
| ) | |
| max_seq_length = gr.Slider( | |
| minimum=128, maximum=2048, value=512, step=64, | |
| label="Max Sequence Length", | |
| info="Maximum length of input sequences" | |
| ) | |
| with gr.Accordion("Advanced Options", open=False): | |
| gradient_accumulation_steps = gr.Slider( | |
| minimum=1, maximum=16, value=1, step=1, | |
| label="Gradient Accumulation Steps", | |
| info="Accumulate gradients over multiple batches to simulate larger batch size" | |
| ) | |
| warmup_steps = gr.Slider( | |
| minimum=0, maximum=500, value=100, step=10, | |
| label="Warmup Steps", | |
| info="Number of steps for learning rate warmup" | |
| ) | |
| weight_decay = gr.Slider( | |
| minimum=0, maximum=0.1, value=0.01, step=0.001, | |
| label="Weight Decay", | |
| info="L2 regularization factor to prevent overfitting" | |
| ) | |
| lora_r = gr.Slider( | |
| minimum=1, maximum=64, value=16, step=1, | |
| label="LoRA Rank (r)", | |
| info="Rank of LoRA adaptors (lower value = smaller model)" | |
| ) | |
| lora_alpha = gr.Slider( | |
| minimum=1, maximum=64, value=32, step=1, | |
| label="LoRA Alpha", | |
| info="LoRA scaling factor (higher = stronger adaptation)" | |
| ) | |
| lora_dropout = gr.Slider( | |
| minimum=0, maximum=0.5, value=0.05, step=0.01, | |
| label="LoRA Dropout", | |
| info="Dropout probability for LoRA layers" | |
| ) | |
| save_config_btn = gr.Button("Save Configuration", variant="primary") | |
| with gr.Column(): | |
| config_info = gr.JSON(label="Current Configuration") | |
| config_feedback = gr.Markdown( | |
| "", | |
| elem_classes=["feedback-div"] | |
| ) | |
        def save_config_handler(
            model, lr, bs, epochs, seq_len, grad_accum, warmup,
            weight_decay, lora_r, lora_alpha, lora_dropout, current_state
        ):
            # A dataset must be processed before a configuration can be saved
            if current_state["processed_dataset"] is None:
                return (
                    current_state,
                    None,
                    gr.update(value="⚠️ Please process a dataset first in the Dataset Upload tab",
                              elem_classes=["feedback-div", "error"])
                )
            config = {
                "model_name": model,
                "learning_rate": lr,
                "batch_size": bs,
                "num_epochs": epochs,
                "max_seq_length": seq_len,
                "gradient_accumulation_steps": grad_accum,
                "warmup_steps": warmup,
                "weight_decay": weight_decay,
                "lora_r": lora_r,
                "lora_alpha": lora_alpha,
                "lora_dropout": lora_dropout
            }
            # Update state
            current_state = current_state.copy()
            current_state["model_name"] = model
            current_state["training_params"] = config
            return (
                current_state,
                config,
                gr.update(value="✅ Configuration saved successfully",
                          elem_classes=["feedback-div", "success"])
            )
        save_config_btn.click(
            save_config_handler,
            inputs=[
                model_name, learning_rate, batch_size, num_epochs, max_seq_length,
                gradient_accumulation_steps, warmup_steps, weight_decay,
                lora_r, lora_alpha, lora_dropout, state
            ],
            outputs=[state, config_info, config_feedback]
        )
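        # Effective batch size is batch_size * gradient_accumulation_steps:
        # e.g. batch_size=4 with 4 accumulation steps updates weights on 16
        # samples' worth of gradients while only holding 4 samples in memory.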
| with gr.Tab("Training"): | |
| gr.Markdown("## Train your model and monitor progress") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| start_btn = gr.Button("Start Training", variant="primary", interactive=True) | |
| stop_btn = gr.Button("Stop Training", variant="stop", interactive=False) | |
| with gr.Accordion("Training Status", open=True): | |
| status = gr.Markdown("Not started", elem_classes=["feedback-div", "info"]) | |
| progress = gr.Slider( | |
| minimum=0, maximum=100, value=0, label="Training Progress", interactive=False | |
| ) | |
| current_epoch = gr.Number(label="Current Epoch", value=0, interactive=False) | |
| current_step = gr.Number(label="Current Step", value=0, interactive=False) | |
| elapsed_time = gr.Textbox(label="Elapsed Time", value="00:00:00", interactive=False) | |
| with gr.Column(scale=2): | |
| with gr.Row(): | |
| with gr.Column(): | |
| loss_plot = gr.Plot(label="Training Loss") | |
| with gr.Column(): | |
| eval_plot = gr.Plot(label="Evaluation Metrics") | |
| training_log = gr.Textbox( | |
| label="Training Log", | |
| interactive=False, | |
| lines=10 | |
| ) | |
| with gr.Accordion("Sample Generations", open=True): | |
| sample_outputs = gr.Dataframe( | |
| headers=["Prompt", "Generated Text", "Reference"], | |
| label="Sample Model Outputs", | |
| wrap=True | |
| ) | |
| # Timer for UI updates | |
| ui_update_interval = gr.Number(value=1, visible=False) | |
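        # Assumed contract for fine_tuning.start_fine_tuning (that module is not
        # shown here): it launches training in the background and returns a
        # thread-like handle exposing .stop() for cooperative cancellation.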
        def start_training_handler(current_state):
            # Validate state before launching training
            if current_state["processed_dataset"] is None:
                return (
                    current_state,
                    gr.update(value="⚠️ Please process a dataset first", elem_classes=["feedback-div", "error"]),
                    gr.update(interactive=True),
                    gr.update(interactive=False)
                )
            if current_state["training_params"] is None:
                return (
                    current_state,
                    gr.update(value="⚠️ Please configure training parameters first", elem_classes=["feedback-div", "error"]),
                    gr.update(interactive=True),
                    gr.update(interactive=False)
                )
            # Start training in a background thread
            try:
                train_thread = start_fine_tuning(
                    model_name=current_state["model_name"],
                    dataset=current_state["processed_dataset"],
                    params=current_state["training_params"]
                )
                current_state = current_state.copy()
                current_state["training_thread"] = train_thread
                return (
                    current_state,
                    gr.update(value="✅ Training started", elem_classes=["feedback-div", "success"]),
                    gr.update(interactive=False),
                    gr.update(interactive=True)
                )
            except Exception as e:
                return (
                    current_state,
                    gr.update(value=f"⚠️ Error starting training: {str(e)}", elem_classes=["feedback-div", "error"]),
                    gr.update(interactive=True),
                    gr.update(interactive=False)
                )
        def stop_training_handler(current_state):
            if current_state.get("training_thread") is not None:
                # Signal the training thread to stop
                current_state["training_thread"].stop()
                current_state = current_state.copy()
                current_state["training_thread"] = None
                return (
                    current_state,
                    gr.update(value="⚠️ Training stopped by user", elem_classes=["feedback-div", "error"]),
                    gr.update(interactive=True),
                    gr.update(interactive=False)
                )
            else:
                return (
                    current_state,
                    gr.update(value="⚠️ No active training to stop", elem_classes=["feedback-div", "error"]),
                    gr.update(interactive=True),
                    gr.update(interactive=False)
                )
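        # load_training_state() is assumed to return None before the first run,
        # and afterwards a dict with at least: status ("running"/"completed"/
        # "error"/"stopped"), current_step, total_steps, current_epoch,
        # elapsed_time (seconds), loss_plot, eval_plot, log, samples, and
        # error (only read when status == "error").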
        def update_training_ui():
            training_state = load_training_state()
            if training_state is None:
                return (
                    0, 0, 0, "00:00:00", None, None, "", None,
                    gr.update(value="Not started", elem_classes=["feedback-div", "info"])
                )
            # Calculate progress percentage
            total_steps = training_state["total_steps"]
            current_step = training_state["current_step"]
            progress_pct = (current_step / total_steps * 100) if total_steps > 0 else 0
            # Format elapsed time as HH:MM:SS
            hours, remainder = divmod(training_state["elapsed_time"], 3600)
            minutes, seconds = divmod(remainder, 60)
            time_str = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"
            # Update status message
            if training_state["status"] == "completed":
                status_msg = gr.update(value="✅ Training completed successfully", elem_classes=["feedback-div", "success"])
            elif training_state["status"] == "error":
                status_msg = gr.update(value=f"⚠️ Training error: {training_state['error']}", elem_classes=["feedback-div", "error"])
            elif training_state["status"] == "stopped":
                status_msg = gr.update(value="⚠️ Training stopped by user", elem_classes=["feedback-div", "error"])
            else:
                status_msg = gr.update(value="⏳ Training in progress...", elem_classes=["feedback-div", "info"])
            return (
                progress_pct,
                training_state["current_epoch"],
                current_step,
                time_str,
                training_state["loss_plot"],
                training_state["eval_plot"],
                training_state["log"],
                training_state["samples"],
                status_msg
            )
        start_btn.click(
            start_training_handler,
            inputs=[state],
            outputs=[state, status, start_btn, stop_btn]
        )
        stop_btn.click(
            stop_training_handler,
            inputs=[state],
            outputs=[state, status, start_btn, stop_btn]
        )
        # Manual refresh button; the script below clicks it periodically
        manual_refresh = gr.Button("Refresh Status", visible=True)
        manual_refresh.click(
            update_training_ui,
            inputs=None,
            outputs=[
                progress, current_epoch, current_step, elapsed_time,
                loss_plot, eval_plot, training_log, sample_outputs, status
            ]
        )
        # Add auto-refresh functionality with an HTML component. Note that
        # :contains() is a jQuery extension, not valid CSS, so the refresh
        # button is located by matching its text content instead.
        auto_refresh = gr.HTML("""
        <script>
        // Click the "Refresh Status" button every 2 seconds
        function setupAutoRefresh() {
            setInterval(function() {
                const buttons = document.querySelectorAll('button');
                for (const btn of buttons) {
                    if (btn.textContent.trim() === 'Refresh Status') {
                        btn.click();
                        break;
                    }
                }
            }, 2000);
        }
        // Set up the auto-refresh when the page loads
        if (window.addEventListener) {
            window.addEventListener('load', setupAutoRefresh, false);
        }
        </script>
        <p style="margin-top: 5px; font-size: 0.8em; color: #666;">Auto-refreshing status every 2 seconds</p>
        """)
        # Initial UI update when the app loads
        demo.load(
            update_training_ui,
            inputs=None,
            outputs=[
                progress, current_epoch, current_step, elapsed_time,
                loss_plot, eval_plot, training_log, sample_outputs, status
            ]
        )
| with gr.Tab("Export Model"): | |
| gr.Markdown("## Export your fine-tuned model") | |
| with gr.Row(): | |
| with gr.Column(): | |
| export_format = gr.Radio( | |
| ["PyTorch", "GGUF", "Safetensors"], | |
| label="Export Format", | |
| value="PyTorch" | |
| ) | |
| quantization = gr.Dropdown( | |
| ["None", "int8", "int4"], | |
| label="Quantization (GGUF only)", | |
| value="None", | |
| interactive=True | |
| ) | |
| model_name_input = gr.Textbox( | |
| label="Model Name", | |
| placeholder="my-fine-tuned-gemma", | |
| value="my-fine-tuned-gemma" | |
| ) | |
| output_dir = gr.Textbox( | |
| label="Output Directory", | |
| placeholder="Path to save the exported model", | |
| value="./exports" | |
| ) | |
| export_btn = gr.Button("Export Model", variant="primary") | |
| with gr.Column(): | |
| export_info = gr.JSON(label="Export Information", visible=False) | |
| export_status = gr.Markdown( | |
| "", | |
| elem_classes=["feedback-div"] | |
| ) | |
        # gr.Progress is not a layout component: it is requested as a default
        # argument of the event handler, and Gradio injects a tracker there.
        def export_model_handler(current_state, export_fmt, quant, name, out_dir,
                                 progress=gr.Progress()):
            if current_state.get("fine_tuned_model_path") is None:
                return (
                    gr.update(value="⚠️ No fine-tuned model available. Please complete training first.",
                              elem_classes=["feedback-div", "error"]),
                    None
                )
            try:
                progress(0.5, desc="Exporting model")
                # The actual export would be implemented in another function
                export_path = os.path.join(out_dir, name)
                os.makedirs(export_path, exist_ok=True)
                export_details = {
                    "format": export_fmt,
                    "quantization": quant if export_fmt == "GGUF" else "None",
                    "model_name": name,
                    "export_path": export_path,
                    "model_size": "0.5 GB",  # would be calculated during the actual export
                    "export_time": "00:01:23"  # would be measured during the actual export
                }
                return (
                    gr.update(value=f"✅ Model exported successfully to {export_path}",
                              elem_classes=["feedback-div", "success"]),
                    export_details
                )
            except Exception as e:
                return (
                    gr.update(value=f"⚠️ Error exporting model: {str(e)}",
                              elem_classes=["feedback-div", "error"]),
                    None
                )
        export_btn.click(
            export_model_handler,
            inputs=[state, export_format, quantization, model_name_input, output_dir],
            outputs=[export_status, export_info]
        )
| with gr.Tab("Documentation"): | |
| gr.Markdown(""" | |
| # Gemma Fine-Tuning Documentation | |
| ## Supported Models | |
| This application supports fine-tuning the following Gemma models: | |
| - google/gemma-2-2b-it | |
| - google/gemma-2-9b-it | |
| - google/gemma-2-27b-it | |
| ## Dataset Format | |
| Your dataset should follow one of these formats: | |
| ### CSV | |
| ``` | |
| prompt,completion | |
| "What is the capital of France?","The capital of France is Paris." | |
| "How does photosynthesis work?","Photosynthesis is the process..." | |
| ``` | |
| ### JSONL | |
| ``` | |
| {"prompt": "What is the capital of France?", "completion": "The capital of France is Paris."} | |
| {"prompt": "How does photosynthesis work?", "completion": "Photosynthesis is the process..."} | |
| ``` | |
| ### Plain Text | |
| ``` | |
| What is the capital of France? | |
| ### | |
| The capital of France is Paris. | |
| ### | |
| How does photosynthesis work? | |
| ### | |
| Photosynthesis is the process... | |
| ``` | |
| ## Fine-Tuning Parameters | |
| ### Basic Parameters | |
| - **Learning Rate**: Controls how quickly the model adapts to the training data. Typical values range from 1e-5 to 5e-5. | |
| - **Batch Size**: Number of samples processed before model weights are updated. Higher values require more memory. | |
| - **Number of Epochs**: Number of complete passes through the training dataset. More epochs can lead to better results but may cause overfitting. | |
| - **Max Sequence Length**: Maximum length of input sequences. Longer sequences require more memory. | |
| ### Advanced Parameters | |
| - **Gradient Accumulation Steps**: Accumulate gradients over multiple batches to simulate larger batch size. | |
| - **Warmup Steps**: Number of steps for learning rate warmup. Helps stabilize training in the early phases. | |
| - **Weight Decay**: L2 regularization factor to prevent overfitting. | |
| - **LoRA Parameters**: Controls the behavior of LoRA (Low-Rank Adaptation), a parameter-efficient fine-tuning technique. | |
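
As a rough illustration, here is how the LoRA sliders would map onto an adapter configuration, assuming the Hugging Face `peft` library is used in the training backend (the actual wiring lives outside this UI):

```
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,               # LoRA Rank (r) slider
    lora_alpha=32,      # LoRA Alpha slider
    lora_dropout=0.05,  # LoRA Dropout slider
    task_type="CAUSAL_LM",
)
```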

## Export Formats
- **PyTorch**: Standard PyTorch model format (.pt or .bin files with model architecture).
- **GGUF**: Compact format optimized for efficient inference (especially with llama.cpp).
- **Safetensors**: Safe format for storing tensors, preventing arbitrary code execution.
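
For example, a Safetensors export can be loaded back without executing any pickled code, assuming the `safetensors` package is installed (the file path below is illustrative):

```
from safetensors.torch import load_file

state_dict = load_file("exports/my-fine-tuned-gemma/model.safetensors")
```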

## Quantization
Quantization reduces model size and increases inference speed at the cost of some accuracy:
- **None**: No quantization, full precision (usually FP16 or BF16).
- **int8**: 8-bit integer quantization, good balance of speed and accuracy.
- **int4**: 4-bit integer quantization, fastest but may reduce accuracy more significantly.
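
As a back-of-the-envelope example: a 2B-parameter model stored at 16 bits takes about 2B × 2 bytes ≈ 4 GB, while int8 roughly halves that to ≈ 2 GB and int4 brings it to ≈ 1 GB (plus some overhead for quantization scales and metadata).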
| """) | |

if __name__ == "__main__":
    demo.launch()