# Import environment setup before any other imports
from env_setup import setup_environment
setup_environment()

import os
import shutil
import tempfile

import gradio as gr

from model_utils import load_model, get_available_models
from data_processing import process_dataset, validate_dataset
from fine_tuning import start_fine_tuning, load_training_state
CSS = """
.feedback-div {
padding: 10px;
margin-bottom: 10px;
border-radius: 5px;
}
.success {
background-color: #d4edda;
color: #155724;
border: 1px solid #c3e6cb;
}
.error {
background-color: #f8d7da;
color: #721c24;
border: 1px solid #f5c6cb;
}
.info {
background-color: #d1ecf1;
color: #0c5460;
border: 1px solid #bee5eb;
}
"""
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
# Store state across tabs
state = gr.State({
"dataset_path": None,
"processed_dataset": None,
"model_name": None,
"model_instance": None,
"training_params": None,
"fine_tuned_model_path": None,
"training_logs": []
})
with gr.Sidebar():
gr.Markdown("# Gemma Fine-Tuning UI")
gr.Markdown("Sign in with your Hugging Face account to use the Nebius API for inference and model access.")
button = gr.LoginButton("Sign in")
gr.Markdown("## Navigation")
with gr.Tab("Introduction"):
gr.Markdown("""
# Welcome to Gemma Fine-Tuning UI
This application allows you to fine-tune Google's Gemma models on your own datasets with a user-friendly interface.
## Features:
- Upload and preprocess your datasets in various formats (CSV, JSONL, TXT)
- Configure model hyperparameters for optimal performance
- Visualize training progress in real-time
- Export your fine-tuned model in different formats
## Getting Started:
1. Navigate to the **Dataset Upload** tab to prepare your data
2. Configure your model and hyperparameters in the **Model Configuration** tab
3. Start and monitor training in the **Training** tab
4. Export your fine-tuned model in the **Export Model** tab
For more details, check the Documentation tab.
""")
with gr.Tab("Dataset Upload"):
gr.Markdown("## Upload and prepare your dataset for fine-tuning")
with gr.Row():
with gr.Column():
dataset_file = gr.File(
label="Upload Dataset File (CSV, JSONL, or TXT)",
file_types=[".csv", ".jsonl", ".json", ".txt"]
)
data_format = gr.Radio(
["CSV", "JSONL", "Plain Text"],
label="Data Format",
value="CSV"
)
with gr.Accordion("CSV Options", open=False):
csv_prompt_col = gr.Textbox(label="Prompt Column Name", value="prompt")
csv_completion_col = gr.Textbox(label="Completion Column Name", value="completion")
csv_separator = gr.Textbox(label="Column Separator", value=",")
with gr.Accordion("JSONL Options", open=False):
jsonl_prompt_key = gr.Textbox(label="Prompt Key", value="prompt")
jsonl_completion_key = gr.Textbox(label="Completion Key", value="completion")
with gr.Accordion("Text Options", open=False):
text_separator = gr.Textbox(
label="Prompt/Completion Separator",
value="###",
info="Symbol or text that separates prompts from completions"
)
process_btn = gr.Button("Process Dataset", variant="primary")
with gr.Column():
dataset_info = gr.JSON(label="Dataset Information", visible=True)
preview_df = gr.Dataframe(label="Data Preview", wrap=True)
dataset_feedback = gr.Markdown(
"",
elem_classes=["feedback-div"]
)
def process_dataset_handler(
file, data_format, csv_prompt, csv_completion, csv_sep,
jsonl_prompt, jsonl_completion, text_sep, current_state
):
if file is None:
return (
current_state,
None,
gr.update(value="⚠️ Please upload a file first", elem_classes=["feedback-div", "error"]),
None
)
try:
# Copy the upload into a temporary working directory. Depending on the
# Gradio version, the File component yields either a filepath string or
# an object exposing the path as `.name`; handle both, and take the
# basename so an absolute upload path cannot override temp_dir.
src_path = file if isinstance(file, str) else file.name
temp_dir = tempfile.mkdtemp()
file_path = os.path.join(temp_dir, os.path.basename(src_path))
shutil.copyfile(src_path, file_path)
# Prepare format-specific options
options = {
"format": data_format.lower(),
"csv_prompt_col": csv_prompt,
"csv_completion_col": csv_completion,
"csv_separator": csv_sep,
"jsonl_prompt_key": jsonl_prompt,
"jsonl_completion_key": jsonl_completion,
"text_separator": text_sep
}
# Validate the dataset
is_valid, message = validate_dataset(file_path, options)
if not is_valid:
return (
current_state,
None,
gr.update(value=f"⚠️ {message}", elem_classes=["feedback-div", "error"]),
None
)
# Process the dataset
processed_data, stats, preview = process_dataset(file_path, options)
# Update state
current_state = current_state.copy()
current_state["dataset_path"] = file_path
current_state["processed_dataset"] = processed_data
return (
current_state,
stats,
gr.update(value="✅ Dataset processed successfully", elem_classes=["feedback-div", "success"]),
preview
)
except Exception as e:
return (
current_state,
None,
gr.update(value=f"⚠️ Error processing dataset: {str(e)}", elem_classes=["feedback-div", "error"]),
None
)
process_btn.click(
process_dataset_handler,
inputs=[
dataset_file, data_format,
csv_prompt_col, csv_completion_col, csv_separator,
jsonl_prompt_key, jsonl_completion_key,
text_separator, state
],
outputs=[state, dataset_info, dataset_feedback, preview_df]
)
with gr.Tab("Model Configuration"):
gr.Markdown("## Select a model and configure hyperparameters")
with gr.Row():
with gr.Column():
model_name = gr.Dropdown(
choices=get_available_models(),
label="Select Base Model",
value="google/gemma-2-2b-it"
)
with gr.Accordion("Training Parameters", open=True):
learning_rate = gr.Slider(
minimum=1e-6, maximum=1e-3, value=2e-5, step=1e-6,
label="Learning Rate",
info="Controls how quickly the model adapts to the training data"
)
batch_size = gr.Slider(
minimum=1, maximum=32, value=4, step=1,
label="Batch Size",
info="Number of samples processed before model weights are updated"
)
num_epochs = gr.Slider(
minimum=1, maximum=10, value=3, step=1,
label="Number of Epochs",
info="Number of complete passes through the training dataset"
)
max_seq_length = gr.Slider(
minimum=128, maximum=2048, value=512, step=64,
label="Max Sequence Length",
info="Maximum length of input sequences"
)
with gr.Accordion("Advanced Options", open=False):
gradient_accumulation_steps = gr.Slider(
minimum=1, maximum=16, value=1, step=1,
label="Gradient Accumulation Steps",
info="Accumulate gradients over multiple batches to simulate larger batch size"
)
warmup_steps = gr.Slider(
minimum=0, maximum=500, value=100, step=10,
label="Warmup Steps",
info="Number of steps for learning rate warmup"
)
weight_decay = gr.Slider(
minimum=0, maximum=0.1, value=0.01, step=0.001,
label="Weight Decay",
info="L2 regularization factor to prevent overfitting"
)
lora_r = gr.Slider(
minimum=1, maximum=64, value=16, step=1,
label="LoRA Rank (r)",
info="Rank of LoRA adaptors (lower value = smaller model)"
)
lora_alpha = gr.Slider(
minimum=1, maximum=64, value=32, step=1,
label="LoRA Alpha",
info="LoRA scaling factor (higher = stronger adaptation)"
)
lora_dropout = gr.Slider(
minimum=0, maximum=0.5, value=0.05, step=0.01,
label="LoRA Dropout",
info="Dropout probability for LoRA layers"
)
save_config_btn = gr.Button("Save Configuration", variant="primary")
with gr.Column():
config_info = gr.JSON(label="Current Configuration")
config_feedback = gr.Markdown(
"",
elem_classes=["feedback-div"]
)
def save_config_handler(
model, lr, bs, epochs, seq_len, grad_accum, warmup,
weight_decay, lora_r, lora_alpha, lora_dropout, current_state
):
# Check if dataset is processed
if current_state["processed_dataset"] is None:
return (
current_state,
None,
gr.update(value="⚠️ Please process a dataset first in the Dataset Upload tab",
elem_classes=["feedback-div", "error"])
)
config = {
"model_name": model,
"learning_rate": lr,
"batch_size": bs,
"num_epochs": epochs,
"max_seq_length": seq_len,
"gradient_accumulation_steps": grad_accum,
"warmup_steps": warmup,
"weight_decay": weight_decay,
"lora_r": lora_r,
"lora_alpha": lora_alpha,
"lora_dropout": lora_dropout
}
# Update state
current_state = current_state.copy()
current_state["model_name"] = model
current_state["training_params"] = config
return (
current_state,
config,
gr.update(value="✅ Configuration saved successfully",
elem_classes=["feedback-div", "success"])
)
save_config_btn.click(
save_config_handler,
inputs=[
model_name, learning_rate, batch_size, num_epochs, max_seq_length,
gradient_accumulation_steps, warmup_steps, weight_decay,
lora_r, lora_alpha, lora_dropout, state
],
outputs=[state, config_info, config_feedback]
)
with gr.Tab("Training"):
gr.Markdown("## Train your model and monitor progress")
with gr.Row():
with gr.Column(scale=1):
start_btn = gr.Button("Start Training", variant="primary", interactive=True)
stop_btn = gr.Button("Stop Training", variant="stop", interactive=False)
with gr.Accordion("Training Status", open=True):
status = gr.Markdown("Not started", elem_classes=["feedback-div", "info"])
progress = gr.Slider(
minimum=0, maximum=100, value=0, label="Training Progress", interactive=False
)
current_epoch = gr.Number(label="Current Epoch", value=0, interactive=False)
current_step = gr.Number(label="Current Step", value=0, interactive=False)
elapsed_time = gr.Textbox(label="Elapsed Time", value="00:00:00", interactive=False)
with gr.Column(scale=2):
with gr.Row():
with gr.Column():
loss_plot = gr.Plot(label="Training Loss")
with gr.Column():
eval_plot = gr.Plot(label="Evaluation Metrics")
training_log = gr.Textbox(
label="Training Log",
interactive=False,
lines=10
)
with gr.Accordion("Sample Generations", open=True):
sample_outputs = gr.Dataframe(
headers=["Prompt", "Generated Text", "Reference"],
label="Sample Model Outputs",
wrap=True
)
# Hidden placeholder for the UI refresh interval; updates are currently
# driven by the manual refresh button and the auto-refresh script below.
ui_update_interval = gr.Number(value=1, visible=False)
def start_training_handler(current_state):
# Validate state
if current_state["processed_dataset"] is None:
return (
current_state,
gr.update(value="⚠️ Please process a dataset first", elem_classes=["feedback-div", "error"]),
gr.update(interactive=True),
gr.update(interactive=False)
)
if current_state["training_params"] is None:
return (
current_state,
gr.update(value="⚠️ Please configure training parameters first", elem_classes=["feedback-div", "error"]),
gr.update(interactive=True),
gr.update(interactive=False)
)
# Start training in a background thread
try:
train_thread = start_fine_tuning(
model_name=current_state["model_name"],
dataset=current_state["processed_dataset"],
params=current_state["training_params"]
)
current_state = current_state.copy()
current_state["training_thread"] = train_thread
return (
current_state,
gr.update(value="✅ Training started", elem_classes=["feedback-div", "success"]),
gr.update(interactive=False),
gr.update(interactive=True)
)
except Exception as e:
return (
current_state,
gr.update(value=f"⚠️ Error starting training: {str(e)}", elem_classes=["feedback-div", "error"]),
gr.update(interactive=True),
gr.update(interactive=False)
)
def stop_training_handler(current_state):
if "training_thread" in current_state and current_state["training_thread"] is not None:
# Signal the training thread to stop
current_state["training_thread"].stop()
current_state = current_state.copy()
current_state["training_thread"] = None
return (
current_state,
gr.update(value="⚠️ Training stopped by user", elem_classes=["feedback-div", "error"]),
gr.update(interactive=True),
gr.update(interactive=False)
)
else:
return (
current_state,
gr.update(value="⚠️ No active training to stop", elem_classes=["feedback-div", "error"]),
gr.update(interactive=True),
gr.update(interactive=False)
)
def update_training_ui():
training_state = load_training_state()
if training_state is None:
return (
0, 0, 0, "00:00:00", None, None, "", None,
gr.update(value="Not started", elem_classes=["feedback-div", "info"])
)
# Calculate progress percentage
total_steps = training_state["total_steps"]
current_step = training_state["current_step"]
progress_pct = (current_step / total_steps * 100) if total_steps > 0 else 0
# Format elapsed time
hours, remainder = divmod(training_state["elapsed_time"], 3600)
minutes, seconds = divmod(remainder, 60)
time_str = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"
# Update status message
if training_state["status"] == "completed":
status_msg = gr.update(value="✅ Training completed successfully", elem_classes=["feedback-div", "success"])
elif training_state["status"] == "error":
status_msg = gr.update(value=f"⚠️ Training error: {training_state['error']}", elem_classes=["feedback-div", "error"])
elif training_state["status"] == "stopped":
status_msg = gr.update(value="⚠️ Training stopped by user", elem_classes=["feedback-div", "error"])
else:
status_msg = gr.update(value="⏳ Training in progress...", elem_classes=["feedback-div", "info"])
return (
progress_pct,
training_state["current_epoch"],
current_step,
time_str,
training_state["loss_plot"],
training_state["eval_plot"],
training_state["log"],
training_state["samples"],
status_msg
)
start_btn.click(
start_training_handler,
inputs=[state],
outputs=[state, status, start_btn, stop_btn]
)
stop_btn.click(
stop_training_handler,
inputs=[state],
outputs=[state, status, start_btn, stop_btn]
)
# Manual refresh button; also serves as the click target for the
# auto-refresh script injected below.
manual_refresh = gr.Button("Refresh Status", visible=True)
manual_refresh.click(
update_training_ui,
inputs=None,
outputs=[
progress, current_epoch, current_step, elapsed_time,
loss_plot, eval_plot, training_log, sample_outputs, status
]
)
# Add auto-refresh functionality with HTML component
auto_refresh = gr.HTML("""
<script>
// Click the "Refresh Status" button every 2 seconds. Note that
// ':contains(...)' is a jQuery-only selector and raises a SyntaxError in
// document.querySelector, so match the button text manually instead.
function setupAutoRefresh() {
    setInterval(function() {
        const refreshButton = Array.from(document.querySelectorAll('button'))
            .find(function(btn) { return btn.textContent.includes('Refresh Status'); });
        if (refreshButton) {
            refreshButton.click();
        }
    }, 2000);
}
// Set up the auto-refresh when the page loads. Scripts injected through
// gr.HTML may not execute in every Gradio version (innerHTML does not run
// script tags), in which case the manual refresh button remains the fallback.
window.addEventListener('load', setupAutoRefresh, false);
</script>
<p style="margin-top: 5px; font-size: 0.8em; color: #666;">Auto-refreshing status every 2 seconds</p>
""")
# Initial UI update
demo.load(
update_training_ui,
inputs=None,
outputs=[
progress, current_epoch, current_step, elapsed_time,
loss_plot, eval_plot, training_log, sample_outputs, status
]
)
with gr.Tab("Export Model"):
gr.Markdown("## Export your fine-tuned model")
with gr.Row():
with gr.Column():
export_format = gr.Radio(
["PyTorch", "GGUF", "Safetensors"],
label="Export Format",
value="PyTorch"
)
quantization = gr.Dropdown(
["None", "int8", "int4"],
label="Quantization (GGUF only)",
value="None",
interactive=True
)
model_name_input = gr.Textbox(
label="Model Name",
placeholder="my-fine-tuned-gemma",
value="my-fine-tuned-gemma"
)
output_dir = gr.Textbox(
label="Output Directory",
placeholder="Path to save the exported model",
value="./exports"
)
export_btn = gr.Button("Export Model", variant="primary")
with gr.Column():
# Visible so the export summary returned by the handler actually shows.
export_info = gr.JSON(label="Export Information", visible=True)
export_status = gr.Markdown(
"",
elem_classes=["feedback-div"]
)
# Note: gr.Progress is injected as a default argument of an event handler
# (e.g. `def handler(..., progress=gr.Progress())`) rather than rendered
# as a standalone component, so no separate progress widget is created here.
def export_model_handler(current_state, export_fmt, quant, name, out_dir):
if current_state.get("fine_tuned_model_path") is None:
return (
gr.update(value="⚠️ No fine-tuned model available. Please complete training first.",
elem_classes=["feedback-div", "error"]),
None
)
try:
# Actual export would be implemented in another function
export_path = os.path.join(out_dir, name)
os.makedirs(export_path, exist_ok=True)
# Local name avoids shadowing the export_info JSON component above.
export_details = {
"format": export_fmt,
"quantization": quant if export_fmt == "GGUF" else "None",
"model_name": name,
"export_path": export_path,
"model_size": "0.5 GB",  # placeholder; would be calculated during an actual export
"export_time": "00:01:23"  # placeholder; would be measured during an actual export
}
return (
gr.update(value=f"✅ Model exported successfully to {export_path}",
elem_classes=["feedback-div", "success"]),
export_details
)
except Exception as e:
return (
gr.update(value=f"⚠️ Error exporting model: {str(e)}",
elem_classes=["feedback-div", "error"]),
None
)
export_btn.click(
export_model_handler,
inputs=[state, export_format, quantization, model_name_input, output_dir],
outputs=[export_status, export_info]
)
with gr.Tab("Documentation"):
gr.Markdown("""
# Gemma Fine-Tuning Documentation
## Supported Models
This application supports fine-tuning the following Gemma models:
- google/gemma-2-2b-it
- google/gemma-2-9b-it
- google/gemma-2-27b-it
## Dataset Format
Your dataset should follow one of these formats:
### CSV
```
prompt,completion
"What is the capital of France?","The capital of France is Paris."
"How does photosynthesis work?","Photosynthesis is the process..."
```
### JSONL
```
{"prompt": "What is the capital of France?", "completion": "The capital of France is Paris."}
{"prompt": "How does photosynthesis work?", "completion": "Photosynthesis is the process..."}
```
### Plain Text
```
What is the capital of France?
###
The capital of France is Paris.
###
How does photosynthesis work?
###
Photosynthesis is the process...
```
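For illustration, a sketch of how the "###"-separated plain-text format above could be split into prompt/completion pairs (the app's actual parsing lives in `data_processing.process_dataset` and may differ):
```python
# Illustrative parser for the plain-text format; chunks alternate
# prompt, completion, prompt, completion, ...
def split_pairs(text: str, sep: str = "###"):
    chunks = [c.strip() for c in text.split(sep) if c.strip()]
    return list(zip(chunks[0::2], chunks[1::2]))
```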
## Fine-Tuning Parameters
### Basic Parameters
- **Learning Rate**: Controls how quickly the model adapts to the training data. Typical values range from 1e-5 to 5e-5.
- **Batch Size**: Number of samples processed before model weights are updated. Higher values require more memory.
- **Number of Epochs**: Number of complete passes through the training dataset. More epochs can lead to better results but may cause overfitting.
- **Max Sequence Length**: Maximum length of input sequences. Longer sequences require more memory.
### Advanced Parameters
- **Gradient Accumulation Steps**: Accumulate gradients over multiple batches to simulate a larger batch size (e.g., a batch size of 4 with 4 accumulation steps gives an effective batch size of 16).
- **Warmup Steps**: Number of steps for learning rate warmup. Helps stabilize training in the early phases.
- **Weight Decay**: L2 regularization factor to prevent overfitting.
- **LoRA Parameters**: Control the behavior of LoRA (Low-Rank Adaptation), a parameter-efficient fine-tuning technique; see the sketch below.
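These values map naturally onto a PEFT `LoraConfig`. As an illustrative sketch (assuming the `peft` library; the app's `fine_tuning` module may wire this differently):
```python
from peft import LoraConfig

# Hypothetical mapping of the UI sliders onto a PEFT LoRA configuration.
lora_config = LoraConfig(
    r=16,               # LoRA Rank (r)
    lora_alpha=32,      # LoRA Alpha
    lora_dropout=0.05,  # LoRA Dropout
    task_type="CAUSAL_LM",
)
```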
## Export Formats
- **PyTorch**: Standard PyTorch weights (.pt or .bin files containing the model's state dict).
- **GGUF**: Compact format optimized for efficient inference (especially with llama.cpp).
- **Safetensors**: Safe format for storing tensors, preventing arbitrary code execution.
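For reference, a minimal sketch of a Safetensors export using the `transformers` API (assuming a standard `PreTrainedModel`; paths are hypothetical and the app's real export logic lives elsewhere):
```python
from transformers import AutoModelForCausalLM

# Hypothetical checkpoint and output paths, for illustration only.
model = AutoModelForCausalLM.from_pretrained("./checkpoints/my-fine-tuned-gemma")
model.save_pretrained("./exports/my-fine-tuned-gemma", safe_serialization=True)
```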
## Quantization
Quantization reduces model size and increases inference speed at the cost of some accuracy:
- **None**: No quantization, full precision (usually FP16 or BF16).
- **int8**: 8-bit integer quantization, good balance of speed and accuracy.
- **int4**: 4-bit integer quantization, fastest but may reduce accuracy more significantly.
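As a point of comparison, runtime quantization with `bitsandbytes` (distinct from GGUF export-time quantization, shown only to illustrate the int8/int4 trade-off; assumes a CUDA-capable environment):
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Load the base model with 4-bit weights; "nf4" is a commonly used quant type.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    quantization_config=bnb_config,
)
```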
""")
if __name__ == "__main__":
    demo.launch()