fsadeek committed
Commit 557c6b6 · 1 Parent(s): 08c68bc

added some features

README.md CHANGED
@@ -9,7 +9,56 @@ app_file: app.py
  pinned: false
  hf_oauth: true
  hf_oauth_scopes:
- - inference-api
+ - inference-api
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Gemma Fine-Tuning UI
+
+ A user-friendly web interface for fine-tuning Google's Gemma models on custom datasets.
+
+ ## Features
+
+ - **Easy Dataset Upload**: Support for CSV, JSONL, and plain text formats
+ - **Intuitive Hyperparameter Configuration**: Adjust learning rates, batch sizes, and other parameters with visual controls
+ - **Real-time Training Visualization**: Monitor loss curves, evaluation metrics, and sample outputs during training
+ - **Flexible Model Export**: Download your fine-tuned model in PyTorch, GGUF, or Safetensors formats
+ - **Comprehensive Documentation**: Built-in guidance for the fine-tuning process
+
+ ## Getting Started
+
+ ### Prerequisites
+
+ - Python 3.8 or later
+ - PyTorch 2.0 or later
+ - Hugging Face account with access to Gemma models
+
+ ### Installation
+
+ 1. Clone this repository:
+
+ ```bash
+ git clone https://github.com/yourusername/gemma-fine-tuning.git
+ cd gemma-fine-tuning
+ ```
+
+ 2. Install the required packages:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 3. Launch the application:
+
+ ```bash
+ python app.py
+ ```
+
+ 4. Open your browser and navigate to `http://localhost:7860`
+
+ ## Usage Guide
+
+ ### 1. Dataset Preparation
+
+ Prepare your dataset in one of the supported formats:
+
+ **CSV format**:
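For reference, a dataset in the CSV layout described above looks like the short sample below; it mirrors the example shown in the app's built-in Documentation tab, and the column names are configurable in the Dataset Upload tab:

```csv
prompt,completion
"What is the capital of France?","The capital of France is Paris."
"How does photosynthesis work?","Photosynthesis is the process..."
```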
__pycache__/data_processing.cpython-311.pyc ADDED
Binary file (14.1 kB).

__pycache__/env_setup.cpython-311.pyc ADDED
Binary file (1.01 kB).

__pycache__/fine_tuning.cpython-311.pyc ADDED
Binary file (12.5 kB).

__pycache__/model_utils.cpython-311.pyc ADDED
Binary file (6.27 kB).
app.py CHANGED
@@ -1,18 +1,670 @@
  import gradio as gr
 
- with gr.Blocks(fill_height=True) as demo:
  with gr.Sidebar():
- gr.Markdown("# Inference Provider")
- gr.Markdown("This Space showcases the google/gemma-2-2b-it model, served by the nebius API. Sign in with your Hugging Face account to use this API.")
  button = gr.LoginButton("Sign in")
- gr.load("models/google/gemma-2-2b-it", accept_token=button, provider="nebius")
-
- with demo.route("Interface") as incrementer_demo:
- gr.Markdown("This is the second page")
- gr.Textbox()
-
- with demo.route("Test") as incrementer_demo:
- gr.Markdown("This is the second page")
- gr.Textbox()
 
  demo.launch()
1
+ # Import environment setup before any other imports
2
+ from env_setup import setup_environment
3
+ setup_environment()
4
+
5
  import gradio as gr
6
+ import os
7
+ from model_utils import load_model, get_available_models
8
+ from data_processing import process_dataset, validate_dataset
9
+ from fine_tuning import start_fine_tuning, load_training_state
10
+ import tempfile
11
+
12
+ CSS = """
13
+ .feedback-div {
14
+ padding: 10px;
15
+ margin-bottom: 10px;
16
+ border-radius: 5px;
17
+ }
18
+ .success {
19
+ background-color: #d4edda;
20
+ color: #155724;
21
+ border: 1px solid #c3e6cb;
22
+ }
23
+ .error {
24
+ background-color: #f8d7da;
25
+ color: #721c24;
26
+ border: 1px solid #f5c6cb;
27
+ }
28
+ .info {
29
+ background-color: #d1ecf1;
30
+ color: #0c5460;
31
+ border: 1px solid #bee5eb;
32
+ }
33
+ """
34
 
35
+ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
36
+ # Store state across tabs
37
+ state = gr.State({
38
+ "dataset_path": None,
39
+ "processed_dataset": None,
40
+ "model_name": None,
41
+ "model_instance": None,
42
+ "training_params": None,
43
+ "fine_tuned_model_path": None,
44
+ "training_logs": []
45
+ })
46
+
47
  with gr.Sidebar():
48
+ gr.Markdown("# Gemma Fine-Tuning UI")
49
+ gr.Markdown("Sign in with your Hugging Face account to use the Nebius API for inference and model access.")
50
  button = gr.LoginButton("Sign in")
51
+
52
+ gr.Markdown("## Navigation")
53
+
54
+ with gr.Tab("Introduction"):
55
+ gr.Markdown("""
56
+ # Welcome to Gemma Fine-Tuning UI
57
+
58
+ This application allows you to fine-tune Google's Gemma models on your own datasets with a user-friendly interface.
59
+
60
+ ## Features:
61
+ - Upload and preprocess your datasets in various formats (CSV, JSONL, TXT)
62
+ - Configure model hyperparameters for optimal performance
63
+ - Visualize training progress in real-time
64
+ - Export your fine-tuned model in different formats
65
+
66
+ ## Getting Started:
67
+ 1. Navigate to the **Dataset Upload** tab to prepare your data
68
+ 2. Configure your model and hyperparameters in the **Model Configuration** tab
69
+ 3. Start and monitor training in the **Training** tab
70
+ 4. Export your fine-tuned model in the **Export Model** tab
71
+
72
+ For more details, check the Documentation tab.
73
+ """)
74
+
75
+ with gr.Tab("Dataset Upload"):
76
+ gr.Markdown("## Upload and prepare your dataset for fine-tuning")
77
+
78
+ with gr.Row():
79
+ with gr.Column():
80
+ dataset_file = gr.File(
81
+ label="Upload Dataset File (CSV, JSONL, or TXT)",
82
+ file_types=["csv", "jsonl", "json", "txt"]
83
+ )
84
+
85
+ data_format = gr.Radio(
86
+ ["CSV", "JSONL", "Plain Text"],
87
+ label="Data Format",
88
+ value="CSV"
89
+ )
90
+
91
+ with gr.Accordion("CSV Options", open=False):
92
+ csv_prompt_col = gr.Textbox(label="Prompt Column Name", value="prompt")
93
+ csv_completion_col = gr.Textbox(label="Completion Column Name", value="completion")
94
+ csv_separator = gr.Textbox(label="Column Separator", value=",")
95
+
96
+ with gr.Accordion("JSONL Options", open=False):
97
+ jsonl_prompt_key = gr.Textbox(label="Prompt Key", value="prompt")
98
+ jsonl_completion_key = gr.Textbox(label="Completion Key", value="completion")
99
+
100
+ with gr.Accordion("Text Options", open=False):
101
+ text_separator = gr.Textbox(
102
+ label="Prompt/Completion Separator",
103
+ value="###",
104
+ info="Symbol or text that separates prompts from completions"
105
+ )
106
+
107
+ process_btn = gr.Button("Process Dataset", variant="primary")
108
+
109
+ with gr.Column():
110
+ dataset_info = gr.JSON(label="Dataset Information", visible=True)
111
+ preview_df = gr.Dataframe(label="Data Preview", wrap=True)
112
+ dataset_feedback = gr.Markdown(
113
+ "",
114
+ elem_classes=["feedback-div"]
115
+ )
116
+
117
+ def process_dataset_handler(
118
+ file, data_format, csv_prompt, csv_completion, csv_sep,
119
+ jsonl_prompt, jsonl_completion, text_sep, current_state
120
+ ):
121
+ if file is None:
122
+ return (
123
+ current_state,
124
+ None,
125
+ gr.update(value="⚠️ Please upload a file first", elem_classes=["feedback-div", "error"]),
126
+ None
127
+ )
128
+
129
+ try:
130
+ # Create a temporary file to store the uploaded content
131
+ temp_dir = tempfile.mkdtemp()
132
+ file_path = os.path.join(temp_dir, file.name)
133
+
134
+ # Save the uploaded file to the temporary location
135
+ with open(file_path, "wb") as f:
136
+ f.write(file.read())
137
+
138
+ # Prepare format-specific options
139
+ options = {
140
+ "format": data_format.lower(),
141
+ "csv_prompt_col": csv_prompt,
142
+ "csv_completion_col": csv_completion,
143
+ "csv_separator": csv_sep,
144
+ "jsonl_prompt_key": jsonl_prompt,
145
+ "jsonl_completion_key": jsonl_completion,
146
+ "text_separator": text_sep
147
+ }
148
+
149
+ # Validate the dataset
150
+ is_valid, message = validate_dataset(file_path, options)
151
+ if not is_valid:
152
+ return (
153
+ current_state,
154
+ None,
155
+ gr.update(value=f"⚠️ {message}", elem_classes=["feedback-div", "error"]),
156
+ None
157
+ )
158
+
159
+ # Process the dataset
160
+ processed_data, stats, preview = process_dataset(file_path, options)
161
+
162
+ # Update state
163
+ current_state = current_state.copy()
164
+ current_state["dataset_path"] = file_path
165
+ current_state["processed_dataset"] = processed_data
166
+
167
+ return (
168
+ current_state,
169
+ stats,
170
+ gr.update(value="✅ Dataset processed successfully", elem_classes=["feedback-div", "success"]),
171
+ preview
172
+ )
173
+
174
+ except Exception as e:
175
+ return (
176
+ current_state,
177
+ None,
178
+ gr.update(value=f"⚠️ Error processing dataset: {str(e)}", elem_classes=["feedback-div", "error"]),
179
+ None
180
+ )
181
+
182
+ process_btn.click(
183
+ process_dataset_handler,
184
+ inputs=[
185
+ dataset_file, data_format,
186
+ csv_prompt_col, csv_completion_col, csv_separator,
187
+ jsonl_prompt_key, jsonl_completion_key,
188
+ text_separator, state
189
+ ],
190
+ outputs=[state, dataset_info, dataset_feedback, preview_df]
191
+ )
192
+
193
+ with gr.Tab("Model Configuration"):
194
+ gr.Markdown("## Select a model and configure hyperparameters")
195
+
196
+ with gr.Row():
197
+ with gr.Column():
198
+ model_name = gr.Dropdown(
199
+ choices=get_available_models(),
200
+ label="Select Base Model",
201
+ value="google/gemma-2-2b-it"
202
+ )
203
+
204
+ with gr.Accordion("Training Parameters", open=True):
205
+ learning_rate = gr.Slider(
206
+ minimum=1e-6, maximum=1e-3, value=2e-5, step=1e-6,
207
+ label="Learning Rate",
208
+ info="Controls how quickly the model adapts to the training data"
209
+ )
210
+ batch_size = gr.Slider(
211
+ minimum=1, maximum=32, value=4, step=1,
212
+ label="Batch Size",
213
+ info="Number of samples processed before model weights are updated"
214
+ )
215
+ num_epochs = gr.Slider(
216
+ minimum=1, maximum=10, value=3, step=1,
217
+ label="Number of Epochs",
218
+ info="Number of complete passes through the training dataset"
219
+ )
220
+ max_seq_length = gr.Slider(
221
+ minimum=128, maximum=2048, value=512, step=64,
222
+ label="Max Sequence Length",
223
+ info="Maximum length of input sequences"
224
+ )
225
+
226
+ with gr.Accordion("Advanced Options", open=False):
227
+ gradient_accumulation_steps = gr.Slider(
228
+ minimum=1, maximum=16, value=1, step=1,
229
+ label="Gradient Accumulation Steps",
230
+ info="Accumulate gradients over multiple batches to simulate larger batch size"
231
+ )
232
+ warmup_steps = gr.Slider(
233
+ minimum=0, maximum=500, value=100, step=10,
234
+ label="Warmup Steps",
235
+ info="Number of steps for learning rate warmup"
236
+ )
237
+ weight_decay = gr.Slider(
238
+ minimum=0, maximum=0.1, value=0.01, step=0.001,
239
+ label="Weight Decay",
240
+ info="L2 regularization factor to prevent overfitting"
241
+ )
242
+ lora_r = gr.Slider(
243
+ minimum=1, maximum=64, value=16, step=1,
244
+ label="LoRA Rank (r)",
245
+ info="Rank of LoRA adaptors (lower value = smaller model)"
246
+ )
247
+ lora_alpha = gr.Slider(
248
+ minimum=1, maximum=64, value=32, step=1,
249
+ label="LoRA Alpha",
250
+ info="LoRA scaling factor (higher = stronger adaptation)"
251
+ )
252
+ lora_dropout = gr.Slider(
253
+ minimum=0, maximum=0.5, value=0.05, step=0.01,
254
+ label="LoRA Dropout",
255
+ info="Dropout probability for LoRA layers"
256
+ )
257
+
258
+ save_config_btn = gr.Button("Save Configuration", variant="primary")
259
+
260
+ with gr.Column():
261
+ config_info = gr.JSON(label="Current Configuration")
262
+ config_feedback = gr.Markdown(
263
+ "",
264
+ elem_classes=["feedback-div"]
265
+ )
266
+
267
+ def save_config_handler(
268
+ model, lr, bs, epochs, seq_len, grad_accum, warmup,
269
+ weight_decay, lora_r, lora_alpha, lora_dropout, current_state
270
+ ):
271
+ # Check if dataset is processed
272
+ if current_state["processed_dataset"] is None:
273
+ return (
274
+ current_state,
275
+ None,
276
+ gr.update(value="⚠️ Please process a dataset first in the Dataset Upload tab",
277
+ elem_classes=["feedback-div", "error"])
278
+ )
279
+
280
+ config = {
281
+ "model_name": model,
282
+ "learning_rate": lr,
283
+ "batch_size": bs,
284
+ "num_epochs": epochs,
285
+ "max_seq_length": seq_len,
286
+ "gradient_accumulation_steps": grad_accum,
287
+ "warmup_steps": warmup,
288
+ "weight_decay": weight_decay,
289
+ "lora_r": lora_r,
290
+ "lora_alpha": lora_alpha,
291
+ "lora_dropout": lora_dropout
292
+ }
293
+
294
+ # Update state
295
+ current_state = current_state.copy()
296
+ current_state["model_name"] = model
297
+ current_state["training_params"] = config
298
+
299
+ return (
300
+ current_state,
301
+ config,
302
+ gr.update(value="✅ Configuration saved successfully",
303
+ elem_classes=["feedback-div", "success"])
304
+ )
305
+
306
+ save_config_btn.click(
307
+ save_config_handler,
308
+ inputs=[
309
+ model_name, learning_rate, batch_size, num_epochs, max_seq_length,
310
+ gradient_accumulation_steps, warmup_steps, weight_decay,
311
+ lora_r, lora_alpha, lora_dropout, state
312
+ ],
313
+ outputs=[state, config_info, config_feedback]
314
+ )
315
+
316
+ with gr.Tab("Training"):
317
+ gr.Markdown("## Train your model and monitor progress")
318
+
319
+ with gr.Row():
320
+ with gr.Column(scale=1):
321
+ start_btn = gr.Button("Start Training", variant="primary", interactive=True)
322
+ stop_btn = gr.Button("Stop Training", variant="stop", interactive=False)
323
+
324
+ with gr.Accordion("Training Status", open=True):
325
+ status = gr.Markdown("Not started", elem_classes=["feedback-div", "info"])
326
+ progress = gr.Slider(
327
+ minimum=0, maximum=100, value=0, label="Training Progress", interactive=False
328
+ )
329
+ current_epoch = gr.Number(label="Current Epoch", value=0, interactive=False)
330
+ current_step = gr.Number(label="Current Step", value=0, interactive=False)
331
+ elapsed_time = gr.Textbox(label="Elapsed Time", value="00:00:00", interactive=False)
332
+
333
+ with gr.Column(scale=2):
334
+ with gr.Row():
335
+ with gr.Column():
336
+ loss_plot = gr.Plot(label="Training Loss")
337
+ with gr.Column():
338
+ eval_plot = gr.Plot(label="Evaluation Metrics")
339
+
340
+ training_log = gr.Textbox(
341
+ label="Training Log",
342
+ interactive=False,
343
+ lines=10
344
+ )
345
+
346
+ with gr.Accordion("Sample Generations", open=True):
347
+ sample_outputs = gr.Dataframe(
348
+ headers=["Prompt", "Generated Text", "Reference"],
349
+ label="Sample Model Outputs",
350
+ wrap=True
351
+ )
352
+
353
+ # Timer for UI updates
354
+ ui_update_interval = gr.Number(value=1, visible=False)
355
+
356
+ def start_training_handler(current_state):
357
+ # Validate state
358
+ if current_state["processed_dataset"] is None:
359
+ return (
360
+ current_state,
361
+ gr.update(value="⚠️ Please process a dataset first", elem_classes=["feedback-div", "error"]),
362
+ gr.update(interactive=True),
363
+ gr.update(interactive=False)
364
+ )
365
+
366
+ if current_state["training_params"] is None:
367
+ return (
368
+ current_state,
369
+ gr.update(value="⚠️ Please configure training parameters first", elem_classes=["feedback-div", "error"]),
370
+ gr.update(interactive=True),
371
+ gr.update(interactive=False)
372
+ )
373
+
374
+ # Start training in a background thread
375
+ try:
376
+ train_thread = start_fine_tuning(
377
+ model_name=current_state["model_name"],
378
+ dataset=current_state["processed_dataset"],
379
+ params=current_state["training_params"]
380
+ )
381
+
382
+ current_state = current_state.copy()
383
+ current_state["training_thread"] = train_thread
384
+
385
+ return (
386
+ current_state,
387
+ gr.update(value="✅ Training started", elem_classes=["feedback-div", "success"]),
388
+ gr.update(interactive=False),
389
+ gr.update(interactive=True)
390
+ )
391
+ except Exception as e:
392
+ return (
393
+ current_state,
394
+ gr.update(value=f"⚠️ Error starting training: {str(e)}", elem_classes=["feedback-div", "error"]),
395
+ gr.update(interactive=True),
396
+ gr.update(interactive=False)
397
+ )
398
+
399
+ def stop_training_handler(current_state):
400
+ if "training_thread" in current_state and current_state["training_thread"] is not None:
401
+ # Signal the training thread to stop
402
+ current_state["training_thread"].stop()
403
+
404
+ current_state = current_state.copy()
405
+ current_state["training_thread"] = None
406
+
407
+ return (
408
+ current_state,
409
+ gr.update(value="⚠️ Training stopped by user", elem_classes=["feedback-div", "error"]),
410
+ gr.update(interactive=True),
411
+ gr.update(interactive=False)
412
+ )
413
+ else:
414
+ return (
415
+ current_state,
416
+ gr.update(value="⚠️ No active training to stop", elem_classes=["feedback-div", "error"]),
417
+ gr.update(interactive=True),
418
+ gr.update(interactive=False)
419
+ )
420
+
421
+ def update_training_ui():
422
+ training_state = load_training_state()
423
+
424
+ if training_state is None:
425
+ return (
426
+ 0, 0, 0, "00:00:00", None, None, "", None,
427
+ gr.update(value="Not started", elem_classes=["feedback-div", "info"])
428
+ )
429
+
430
+ # Calculate progress percentage
431
+ total_steps = training_state["total_steps"]
432
+ current_step = training_state["current_step"]
433
+ progress_pct = (current_step / total_steps * 100) if total_steps > 0 else 0
434
+
435
+ # Format elapsed time
436
+ hours, remainder = divmod(training_state["elapsed_time"], 3600)
437
+ minutes, seconds = divmod(remainder, 60)
438
+ time_str = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"
439
+
440
+ # Update status message
441
+ if training_state["status"] == "completed":
442
+ status_msg = gr.update(value="✅ Training completed successfully", elem_classes=["feedback-div", "success"])
443
+ elif training_state["status"] == "error":
444
+ status_msg = gr.update(value=f"⚠️ Training error: {training_state['error']}", elem_classes=["feedback-div", "error"])
445
+ elif training_state["status"] == "stopped":
446
+ status_msg = gr.update(value="⚠️ Training stopped by user", elem_classes=["feedback-div", "error"])
447
+ else:
448
+ status_msg = gr.update(value="⏳ Training in progress...", elem_classes=["feedback-div", "info"])
449
+
450
+ return (
451
+ progress_pct,
452
+ training_state["current_epoch"],
453
+ current_step,
454
+ time_str,
455
+ training_state["loss_plot"],
456
+ training_state["eval_plot"],
457
+ training_state["log"],
458
+ training_state["samples"],
459
+ status_msg
460
+ )
461
+
462
+ start_btn.click(
463
+ start_training_handler,
464
+ inputs=[state],
465
+ outputs=[state, status, start_btn, stop_btn]
466
+ )
467
+
468
+ stop_btn.click(
469
+ stop_training_handler,
470
+ inputs=[state],
471
+ outputs=[state, status, start_btn, stop_btn]
472
+ )
473
+
474
+ # Remove problematic JavaScript loading approach
475
+ # Create a simple manual refresh button for compatibility
476
+ manual_refresh = gr.Button("Refresh Status", visible=True)
477
+ manual_refresh.click(
478
+ update_training_ui,
479
+ inputs=None,
480
+ outputs=[
481
+ progress, current_epoch, current_step, elapsed_time,
482
+ loss_plot, eval_plot, training_log, sample_outputs, status
483
+ ]
484
+ )
485
+
486
+ # Add auto-refresh functionality with HTML component
487
+ auto_refresh = gr.HTML("""
488
+ <script>
489
+ // Auto-refresh the UI every second
490
+ function setupAutoRefresh() {
491
+ setInterval(function() {
492
+ // querySelector does not support the jQuery-only :contains() pseudo-class,
+ // so find the button by its text content instead.
+ const refreshButton = Array.from(document.querySelectorAll('button')).find(b => b.textContent.trim() === 'Refresh Status');
493
+ if (refreshButton) {
494
+ refreshButton.click();
495
+ }
496
+ }, 2000);
497
+ }
498
+
499
+ // Set up the auto-refresh when page loads
500
+ if (window.addEventListener) {
501
+ window.addEventListener('load', setupAutoRefresh, false);
502
+ }
503
+ </script>
504
+ <p style="margin-top: 5px; font-size: 0.8em; color: #666;">Auto-refreshing status every 2 seconds</p>
505
+ """)
506
+
507
+ # Initial UI update
508
+ demo.load(
509
+ update_training_ui,
510
+ inputs=None,
511
+ outputs=[
512
+ progress, current_epoch, current_step, elapsed_time,
513
+ loss_plot, eval_plot, training_log, sample_outputs, status
514
+ ]
515
+ )
516
+
517
+ with gr.Tab("Export Model"):
518
+ gr.Markdown("## Export your fine-tuned model")
519
+
520
+ with gr.Row():
521
+ with gr.Column():
522
+ export_format = gr.Radio(
523
+ ["PyTorch", "GGUF", "Safetensors"],
524
+ label="Export Format",
525
+ value="PyTorch"
526
+ )
527
+
528
+ quantization = gr.Dropdown(
529
+ ["None", "int8", "int4"],
530
+ label="Quantization (GGUF only)",
531
+ value="None",
532
+ interactive=True
533
+ )
534
+
535
+ model_name_input = gr.Textbox(
536
+ label="Model Name",
537
+ placeholder="my-fine-tuned-gemma",
538
+ value="my-fine-tuned-gemma"
539
+ )
540
+
541
+ output_dir = gr.Textbox(
542
+ label="Output Directory",
543
+ placeholder="Path to save the exported model",
544
+ value="./exports"
545
+ )
546
+
547
+ export_btn = gr.Button("Export Model", variant="primary")
548
+
549
+ with gr.Column():
550
+ export_info = gr.JSON(label="Export Information", visible=False)
551
+ export_status = gr.Markdown(
552
+ "",
553
+ elem_classes=["feedback-div"]
554
+ )
555
+ # Fix: Remove 'visible' parameter which is not supported in this Gradio version
556
+ export_progress = gr.Progress()
557
+
558
+ def export_model_handler(current_state, format, quant, name, out_dir):
559
+ if current_state.get("fine_tuned_model_path") is None:
560
+ return (
561
+ gr.update(value="⚠️ No fine-tuned model available. Please complete training first.",
562
+ elem_classes=["feedback-div", "error"]),
563
+ None
564
+ )
565
+
566
+ try:
567
+ # Actual export would be implemented in another function
568
+ export_path = os.path.join(out_dir, name)
569
+ os.makedirs(export_path, exist_ok=True)
570
+
571
+ export_info = {
572
+ "format": format,
573
+ "quantization": quant if format == "GGUF" else "None",
574
+ "model_name": name,
575
+ "export_path": export_path,
576
+ "model_size": "0.5 GB", # This would be calculated during actual export
577
+ "export_time": "00:01:23" # This would be measured during actual export
578
+ }
579
+
580
+ return (
581
+ gr.update(value=f"✅ Model exported successfully to {export_path}",
582
+ elem_classes=["feedback-div", "success"]),
583
+ export_info
584
+ )
585
+ except Exception as e:
586
+ return (
587
+ gr.update(value=f"⚠️ Error exporting model: {str(e)}",
588
+ elem_classes=["feedback-div", "error"]),
589
+ None
590
+ )
591
+
592
+ export_btn.click(
593
+ export_model_handler,
594
+ inputs=[state, export_format, quantization, model_name_input, output_dir],
595
+ # Update outputs list to remove reference to progress visibility
596
+ outputs=[export_status, export_info]
597
+ )
598
+
599
+ with gr.Tab("Documentation"):
600
+ gr.Markdown("""
601
+ # Gemma Fine-Tuning Documentation
602
+
603
+ ## Supported Models
604
+
605
+ This application supports fine-tuning the following Gemma models:
606
+
607
+ - google/gemma-2-2b-it
608
+ - google/gemma-2-9b-it
609
+ - google/gemma-2-27b-it
610
+
611
+ ## Dataset Format
612
+
613
+ Your dataset should follow one of these formats:
614
+
615
+ ### CSV
616
+ ```
617
+ prompt,completion
618
+ "What is the capital of France?","The capital of France is Paris."
619
+ "How does photosynthesis work?","Photosynthesis is the process..."
620
+ ```
621
+
622
+ ### JSONL
623
+ ```
624
+ {"prompt": "What is the capital of France?", "completion": "The capital of France is Paris."}
625
+ {"prompt": "How does photosynthesis work?", "completion": "Photosynthesis is the process..."}
626
+ ```
627
+
628
+ ### Plain Text
629
+ ```
630
+ What is the capital of France?
631
+ ###
632
+ The capital of France is Paris.
633
+ ###
634
+ How does photosynthesis work?
635
+ ###
636
+ Photosynthesis is the process...
637
+ ```
638
+
639
+ ## Fine-Tuning Parameters
640
+
641
+ ### Basic Parameters
642
+
643
+ - **Learning Rate**: Controls how quickly the model adapts to the training data. Typical values range from 1e-5 to 5e-5.
644
+ - **Batch Size**: Number of samples processed before model weights are updated. Higher values require more memory.
645
+ - **Number of Epochs**: Number of complete passes through the training dataset. More epochs can lead to better results but may cause overfitting.
646
+ - **Max Sequence Length**: Maximum length of input sequences. Longer sequences require more memory.
647
+
648
+ ### Advanced Parameters
649
+
650
+ - **Gradient Accumulation Steps**: Accumulate gradients over multiple batches to simulate larger batch size.
651
+ - **Warmup Steps**: Number of steps for learning rate warmup. Helps stabilize training in the early phases.
652
+ - **Weight Decay**: L2 regularization factor to prevent overfitting.
653
+ - **LoRA Parameters**: Controls the behavior of LoRA (Low-Rank Adaptation), a parameter-efficient fine-tuning technique.
654
+
655
+ ## Export Formats
656
+
657
+ - **PyTorch**: Standard PyTorch model format (.pt or .bin files with model architecture).
658
+ - **GGUF**: Compact format optimized for efficient inference (especially with llama.cpp).
659
+ - **Safetensors**: Safe format for storing tensors, preventing arbitrary code execution.
660
+
661
+ ## Quantization
662
+
663
+ Quantization reduces model size and increases inference speed at the cost of some accuracy:
664
+
665
+ - **None**: No quantization, full precision (usually FP16 or BF16).
666
+ - **int8**: 8-bit integer quantization, good balance of speed and accuracy.
667
+ - **int4**: 4-bit integer quantization, fastest but may reduce accuracy more significantly.
668
+ """)
669
 
670
  demo.launch()
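All four tabs in app.py communicate through the single `gr.State` dictionary created at the top of the Blocks context: each handler receives the current dict as an input, copies it, updates the relevant keys, and returns the copy as its first output so later tabs see the change. A minimal standalone sketch of that pattern (using a hypothetical `counter` key rather than the app's real keys):

```python
import gradio as gr

with gr.Blocks() as demo:
    # One per-session dict shared by every handler, mirroring app.py's `state`.
    state = gr.State({"counter": 0})

    button = gr.Button("Increment")
    display = gr.JSON(label="Current state")

    def increment(current_state):
        # Copy before mutating so the returned object is a new value.
        current_state = current_state.copy()
        current_state["counter"] += 1
        return current_state, current_state  # updated state, value to display

    button.click(increment, inputs=[state], outputs=[state, display])

if __name__ == "__main__":
    demo.launch()
```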
data_processing.py ADDED
@@ -0,0 +1,292 @@
1
+ import os
2
+ import json
3
+ import csv
4
+ import pandas as pd
5
+ import random
6
+
7
+ def validate_dataset(file_path, options):
8
+ """
9
+ Validates that a dataset file can be processed with the given options.
10
+
11
+ Args:
12
+ file_path: Path to the dataset file
13
+ options: Dictionary of processing options
14
+
15
+ Returns:
16
+ Tuple of (is_valid, message)
17
+ """
18
+ if not os.path.exists(file_path):
19
+ return False, f"File not found: {file_path}"
20
+
21
+ file_format = options.get("format", "").lower()
22
+
23
+ try:
24
+ if file_format == "csv":
25
+ # Validate CSV format
26
+ separator = options.get("csv_separator", ",")
27
+ prompt_col = options.get("csv_prompt_col", "prompt")
28
+ completion_col = options.get("csv_completion_col", "completion")
29
+
30
+ df = pd.read_csv(file_path, sep=separator)
31
+
32
+ if prompt_col not in df.columns:
33
+ return False, f"Prompt column '{prompt_col}' not found in CSV file"
34
+ if completion_col not in df.columns:
35
+ return False, f"Completion column '{completion_col}' not found in CSV file"
36
+
37
+ # Check for empty values
38
+ if df[prompt_col].isnull().any():
39
+ return False, "CSV file contains empty prompt values"
40
+ if df[completion_col].isnull().any():
41
+ return False, "CSV file contains empty completion values"
42
+
43
+ elif file_format == "jsonl":
44
+ # Validate JSONL format
45
+ prompt_key = options.get("jsonl_prompt_key", "prompt")
46
+ completion_key = options.get("jsonl_completion_key", "completion")
47
+
48
+ with open(file_path, 'r', encoding='utf-8') as f:
49
+ line_count = 0
50
+ for line in f:
51
+ line = line.strip()
52
+ if not line:
53
+ continue
54
+
55
+ data = json.loads(line)
56
+ line_count += 1
57
+
58
+ if prompt_key not in data:
59
+ return False, f"Prompt key '{prompt_key}' not found in JSONL at line {line_count}"
60
+ if completion_key not in data:
61
+ return False, f"Completion key '{completion_key}' not found in JSONL at line {line_count}"
62
+
63
+ if not data[prompt_key] or not isinstance(data[prompt_key], str):
64
+ return False, f"Invalid prompt value at line {line_count}"
65
+ if not data[completion_key] or not isinstance(data[completion_key], str):
66
+ return False, f"Invalid completion value at line {line_count}"
67
+
68
+ if line_count == 0:
69
+ return False, "JSONL file is empty"
70
+
71
+ elif file_format == "plain text":
72
+ # Validate plain text format
73
+ separator = options.get("text_separator", "###")
74
+
75
+ with open(file_path, 'r', encoding='utf-8') as f:
76
+ content = f.read()
77
+
78
+ parts = content.split(separator)
79
+ if len(parts) < 2: # Need at least one prompt and one completion
+ return False, f"Text file doesn't contain enough sections separated by '{separator}'"
+
+ # Parts should pair up (prompt, completion, prompt, completion, ...), so the count must be even
+ if len(parts) % 2 != 0:
+ return False, f"Text file has an unpaired number of sections separated by '{separator}'"
85
+
86
+ else:
87
+ return False, f"Unsupported format: {file_format}"
88
+
89
+ return True, "Dataset is valid"
90
+
91
+ except Exception as e:
92
+ return False, f"Error validating dataset: {str(e)}"
93
+
94
+ def process_dataset(file_path, options):
95
+ """
96
+ Processes a dataset file according to the given options.
97
+
98
+ Args:
99
+ file_path: Path to the dataset file
100
+ options: Dictionary of processing options
101
+
102
+ Returns:
103
+ Tuple of (processed_data, stats, preview)
104
+ """
105
+ file_format = options.get("format", "").lower()
106
+
107
+ if file_format == "csv":
108
+ return _process_csv(file_path, options)
109
+ elif file_format == "jsonl":
110
+ return _process_jsonl(file_path, options)
111
+ elif file_format == "plain text":
112
+ return _process_text(file_path, options)
113
+ else:
114
+ raise ValueError(f"Unsupported format: {file_format}")
115
+
116
+ def _process_csv(file_path, options):
117
+ """Process a CSV dataset file."""
118
+ separator = options.get("csv_separator", ",")
119
+ prompt_col = options.get("csv_prompt_col", "prompt")
120
+ completion_col = options.get("csv_completion_col", "completion")
121
+
122
+ df = pd.read_csv(file_path, sep=separator)
123
+
124
+ # Extract prompts and completions
125
+ data = []
126
+ for _, row in df.iterrows():
127
+ data.append({
128
+ "prompt": str(row[prompt_col]),
129
+ "completion": str(row[completion_col])
130
+ })
131
+
132
+ # Generate statistics
133
+ stats = {
134
+ "num_examples": len(data),
135
+ "avg_prompt_length": sum(len(item["prompt"]) for item in data) / len(data),
136
+ "avg_completion_length": sum(len(item["completion"]) for item in data) / len(data),
137
+ "format": "csv"
138
+ }
139
+
140
+ # Create a preview DataFrame (showing first 5 rows)
141
+ preview = df[[prompt_col, completion_col]].head(5)
142
+
143
+ return data, stats, preview
144
+
145
+ def _process_jsonl(file_path, options):
146
+ """Process a JSONL dataset file."""
147
+ prompt_key = options.get("jsonl_prompt_key", "prompt")
148
+ completion_key = options.get("jsonl_completion_key", "completion")
149
+
150
+ data = []
151
+ with open(file_path, 'r', encoding='utf-8') as f:
152
+ for line in f:
153
+ line = line.strip()
154
+ if not line:
155
+ continue
156
+
157
+ item = json.loads(line)
158
+ data.append({
159
+ "prompt": item[prompt_key],
160
+ "completion": item[completion_key]
161
+ })
162
+
163
+ # Generate statistics
164
+ stats = {
165
+ "num_examples": len(data),
166
+ "avg_prompt_length": sum(len(item["prompt"]) for item in data) / len(data),
167
+ "avg_completion_length": sum(len(item["completion"]) for item in data) / len(data),
168
+ "format": "jsonl"
169
+ }
170
+
171
+ # Create a preview DataFrame
172
+ preview_data = []
173
+ for i, item in enumerate(data[:5]):
174
+ preview_data.append({
175
+ "prompt": item["prompt"],
176
+ "completion": item["completion"]
177
+ })
178
+ preview = pd.DataFrame(preview_data)
179
+
180
+ return data, stats, preview
181
+
182
+ def _process_text(file_path, options):
183
+ """Process a plain text dataset file."""
184
+ separator = options.get("text_separator", "###")
185
+
186
+ with open(file_path, 'r', encoding='utf-8') as f:
187
+ content = f.read()
188
+
189
+ parts = content.split(separator)
190
+
191
+ data = []
192
+ for i in range(0, len(parts) - 1, 2):
193
+ prompt = parts[i].strip()
194
+ completion = parts[i + 1].strip()
195
+
196
+ if prompt and completion:
197
+ data.append({
198
+ "prompt": prompt,
199
+ "completion": completion
200
+ })
201
+
202
+ # Generate statistics
203
+ stats = {
204
+ "num_examples": len(data),
205
+ "avg_prompt_length": sum(len(item["prompt"]) for item in data) / len(data),
206
+ "avg_completion_length": sum(len(item["completion"]) for item in data) / len(data),
207
+ "format": "text"
208
+ }
209
+
210
+ # Create a preview DataFrame
211
+ preview_data = []
212
+ for i, item in enumerate(data[:5]):
213
+ preview_data.append({
214
+ "prompt": item["prompt"],
215
+ "completion": item["completion"]
216
+ })
217
+ preview = pd.DataFrame(preview_data)
218
+
219
+ return data, stats, preview
220
+
221
+ def format_for_training(dataset, tokenizer, max_length=512):
222
+ """
223
+ Formats a processed dataset for training with Gemma.
224
+
225
+ Args:
226
+ dataset: List of prompt/completion pairs
227
+ tokenizer: Tokenizer for the model
228
+ max_length: Maximum sequence length
229
+
230
+ Returns:
231
+ Dictionary of training data
232
+ """
233
+ input_ids = []
234
+ labels = []
235
+ attention_mask = []
236
+
237
+ for item in dataset:
238
+ prompt = item["prompt"]
239
+ completion = item["completion"]
240
+
241
+ # Format as the model expects
242
+ full_text = f"{prompt}{tokenizer.eos_token}{completion}{tokenizer.eos_token}"
243
+
244
+ # Tokenize
245
+ encoded = tokenizer(full_text, max_length=max_length, padding="max_length", truncation=True)
246
+
247
+ # For input_ids, we use the full sequence
248
+ input_ids.append(encoded["input_ids"])
249
+ attention_mask.append(encoded["attention_mask"])
250
+
251
+ # For labels, we set the prompt tokens to -100 so they're ignored in loss calculation
252
+ prompt_encoded = tokenizer(f"{prompt}{tokenizer.eos_token}", add_special_tokens=False)
253
+ prompt_length = len(prompt_encoded["input_ids"])
254
+
255
+ # Create label tensor: -100 for prompt tokens (ignored in loss), actual token IDs for completion
256
+ label = [-100] * prompt_length + encoded["input_ids"][prompt_length:]
257
+
258
+ # Pad to max_length
259
+ if len(label) < max_length:
260
+ label = label + [-100] * (max_length - len(label))
261
+ else:
262
+ label = label[:max_length]
263
+
264
+ labels.append(label)
265
+
266
+ return {
267
+ "input_ids": input_ids,
268
+ "attention_mask": attention_mask,
269
+ "labels": labels
270
+ }
271
+
272
+ def create_train_val_split(dataset, val_size=0.1, seed=42):
273
+ """
274
+ Splits a dataset into training and validation sets.
275
+
276
+ Args:
277
+ dataset: List of examples
278
+ val_size: Fraction of examples to use for validation
279
+ seed: Random seed for reproducibility
280
+
281
+ Returns:
282
+ Tuple of (train_dataset, val_dataset)
283
+ """
284
+ random.seed(seed)
285
+ random.shuffle(dataset)
286
+
287
+ val_count = max(1, int(len(dataset) * val_size))
288
+
289
+ val_dataset = dataset[:val_count]
290
+ train_dataset = dataset[val_count:]
291
+
292
+ return train_dataset, val_dataset
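Taken together, the helpers in data_processing.py are intended to run in sequence: validate the upload, process it into prompt/completion pairs, then split for training. A small usage sketch (the file name is a placeholder; the option keys match the dict that app.py builds):

```python
from data_processing import validate_dataset, process_dataset, create_train_val_split

options = {
    "format": "csv",                # "csv", "jsonl", or "plain text"
    "csv_prompt_col": "prompt",
    "csv_completion_col": "completion",
    "csv_separator": ",",
}

is_valid, message = validate_dataset("my_dataset.csv", options)
if is_valid:
    data, stats, preview = process_dataset("my_dataset.csv", options)
    train_data, val_data = create_train_val_split(data, val_size=0.1)
    print(stats)
    print(f"{len(train_data)} training examples, {len(val_data)} validation examples")
else:
    print("Validation failed:", message)
```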
env_setup.py ADDED
@@ -0,0 +1,26 @@
+ """
+ Environment setup to handle library conflicts and dependencies
+ """
+
+ import os
+ import logging
+
+ def setup_environment():
+ """Configure environment variables for the application"""
+
+ # Disable TensorFlow warnings and prevent it from being loaded
+ # This allows Transformers to work without TensorFlow dependencies
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Disable TensorFlow logging
+ os.environ["USE_TORCH"] = "1"  # Tell Transformers to use PyTorch
+ os.environ["USE_TF"] = "0"  # Tell Transformers not to use TensorFlow
+
+ # Configure logging
+ logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+
+ # Log environment settings
+ logging.info("Environment configured: PyTorch enabled, TensorFlow disabled")
+
+ return True
fine_tuning.py ADDED
@@ -0,0 +1,282 @@
1
+ """
2
+ Functions for fine-tuning Gemma models
3
+ """
4
+
5
+ import os
6
+ import time
7
+ import json
8
+ import threading
9
+ import torch
10
+ import numpy as np
11
+ import matplotlib.pyplot as plt
12
+ import pandas as pd
13
+ from datetime import datetime
14
+ from transformers import (
15
+ AutoModelForCausalLM,
16
+ AutoTokenizer,
17
+ TrainingArguments,
18
+ Trainer,
19
+ DataCollatorForLanguageModeling,
+ TrainerCallback
20
+ )
21
+ from peft import get_peft_model, LoraConfig, TaskType
22
+ from data_processing import create_train_val_split, format_for_training
23
+ from model_utils import load_model
24
+ from datasets import Dataset
25
+
26
+ # Global variable to store training state
27
+ _TRAINING_STATE = None
28
+
29
+ class TrainingThread(threading.Thread):
30
+ """Thread class for running training in the background."""
31
+
32
+ def __init__(self, model_name, dataset, params):
33
+ threading.Thread.__init__(self)
34
+ self.model_name = model_name
35
+ self.dataset = dataset
36
+ self.params = params
37
+ self.stop_flag = False
38
+ self.daemon = True # Thread will exit when main program exits
39
+
40
+ def run(self):
41
+ """Run the training process."""
42
+ try:
43
+ # Initialize training state
44
+ global _TRAINING_STATE
45
+ _TRAINING_STATE = {
46
+ "status": "initializing",
47
+ "current_epoch": 0,
48
+ "current_step": 0,
49
+ "total_steps": 0,
50
+ "elapsed_time": 0,
51
+ "loss_plot": None,
52
+ "eval_plot": None,
53
+ "log": "",
54
+ "samples": None,
55
+ "error": None
56
+ }
57
+
58
+ # Create output directory
59
+ output_dir = os.path.join("outputs", datetime.now().strftime("%Y%m%d_%H%M%S"))
60
+ os.makedirs(output_dir, exist_ok=True)
61
+
62
+ # Load the model and tokenizer
63
+ model, tokenizer = load_model(self.model_name)
64
+
65
+ # Apply LoRA configuration
66
+ lora_config = LoraConfig(
67
+ r=self.params.get("lora_r", 16),
68
+ lora_alpha=self.params.get("lora_alpha", 32),
69
+ lora_dropout=self.params.get("lora_dropout", 0.05),
70
+ bias="none",
71
+ task_type=TaskType.CAUSAL_LM
72
+ )
73
+ model = get_peft_model(model, lora_config)
74
+
75
+ # Split dataset into train and validation
76
+ train_data, val_data = create_train_val_split(self.dataset)
77
+
78
+ # Format data for training
79
+ max_length = self.params.get("max_seq_length", 512)
80
+ train_formatted = format_for_training(train_data, tokenizer, max_length)
81
+ val_formatted = format_for_training(val_data, tokenizer, max_length)
82
+
83
+ # Convert to HF Datasets
84
+ train_dataset = Dataset.from_dict(train_formatted)
85
+ val_dataset = Dataset.from_dict(val_formatted)
86
+
87
+ # Create data collator
88
+ data_collator = DataCollatorForLanguageModeling(
89
+ tokenizer=tokenizer,
90
+ mlm=False
91
+ )
92
+
93
+ # Set up training arguments
94
+ batch_size = self.params.get("batch_size", 4)
95
+ gradient_accumulation_steps = self.params.get("gradient_accumulation_steps", 1)
96
+ num_epochs = self.params.get("num_epochs", 3)
97
+
98
+ # Calculate total steps
99
+ train_steps = len(train_dataset) // batch_size // gradient_accumulation_steps * num_epochs
100
+ _TRAINING_STATE["total_steps"] = train_steps
101
+
102
+ # Training arguments
103
+ training_args = TrainingArguments(
104
+ output_dir=output_dir,
105
+ learning_rate=self.params.get("learning_rate", 2e-5),
106
+ per_device_train_batch_size=batch_size,
107
+ per_device_eval_batch_size=batch_size,
108
+ gradient_accumulation_steps=gradient_accumulation_steps,
109
+ num_train_epochs=num_epochs,
110
+ weight_decay=self.params.get("weight_decay", 0.01),
111
+ warmup_steps=self.params.get("warmup_steps", 100),
112
+ logging_dir=os.path.join(output_dir, "logs"),
113
+ logging_steps=10,
114
+ evaluation_strategy="epoch",
115
+ save_strategy="epoch",
116
+ save_total_limit=2,
117
+ load_best_model_at_end=True,
118
+ report_to="none" # Disable wandb, tensorboard, etc.
119
+ )
120
+
121
+ # Custom callback for UI updates
122
+ class UICallback(TrainerCallback):
123
+ def __init__(self, thread):
124
+ self.thread = thread
125
+ self.start_time = time.time()
126
+ self.losses = []
127
+ self.eval_metrics = []
128
+ self.log_buffer = ""
129
+
130
+ def on_log(self, args, state, control, logs=None, **kwargs):
131
+ if self.thread.stop_flag:
132
+ control.should_training_stop = True
133
+ _TRAINING_STATE["status"] = "stopped"
134
+ return
135
+
136
+ if logs is None:
137
+ return
138
+
139
+ # Update training state
140
+ _TRAINING_STATE["elapsed_time"] = time.time() - self.start_time
141
+
142
+ # Handle training logs
143
+ if "loss" in logs:
144
+ _TRAINING_STATE["current_step"] = state.global_step
145
+ loss = logs["loss"]
146
+ self.losses.append((state.global_step, loss))
147
+
148
+ # Update loss plot
149
+ fig, ax = plt.subplots(figsize=(10, 6))
150
+ steps, losses = zip(*self.losses)
151
+ ax.plot(steps, losses)
152
+ ax.set_xlabel("Steps")
153
+ ax.set_ylabel("Loss")
154
+ ax.set_title("Training Loss")
155
+ ax.grid(True)
156
+ _TRAINING_STATE["loss_plot"] = fig
157
+
158
+ # Update log
159
+ log_entry = f"Step {state.global_step}: loss={loss:.4f}\n"
160
+ self.log_buffer += log_entry
161
+ _TRAINING_STATE["log"] = self.log_buffer
162
+
163
+ # Handle evaluation logs
164
+ if "eval_loss" in logs:
165
+ _TRAINING_STATE["current_epoch"] = state.epoch
166
+ eval_loss = logs["eval_loss"]
167
+ self.eval_metrics.append((state.epoch, eval_loss))
168
+
169
+ # Update eval plot
170
+ fig, ax = plt.subplots(figsize=(10, 6))
171
+ epochs, metrics = zip(*self.eval_metrics)
172
+ ax.plot(epochs, metrics)
173
+ ax.set_xlabel("Epochs")
174
+ ax.set_ylabel("Evaluation Loss")
175
+ ax.set_title("Validation Loss")
176
+ ax.grid(True)
177
+ _TRAINING_STATE["eval_plot"] = fig
178
+
179
+ # Generate sample outputs for visualization
180
+ sample_outputs = self.generate_samples(model, tokenizer)
181
+ _TRAINING_STATE["samples"] = sample_outputs
182
+
183
+ # Update log
184
+ log_entry = f"Epoch {state.epoch}: eval_loss={eval_loss:.4f}\n"
185
+ self.log_buffer += log_entry
186
+ _TRAINING_STATE["log"] = self.log_buffer
187
+
188
+ def generate_samples(self, model, tokenizer, num_samples=3):
189
+ """Generate sample outputs from the current model."""
190
+ # Get random samples from validation set
191
+ val_indices = np.random.choice(len(val_data), min(num_samples, len(val_data)), replace=False)
192
+ samples = [val_data[i] for i in val_indices]
193
+
194
+ results = []
195
+ for sample in samples:
196
+ prompt = sample["prompt"]
197
+ reference = sample["completion"]
198
+
199
+ # Generate text
200
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
201
+ with torch.no_grad():
202
+ outputs = model.generate(
203
+ **inputs,
204
+ max_new_tokens=100,
205
+ do_sample=True,
+ temperature=0.7,
206
+ num_return_sequences=1
207
+ )
208
+
209
+ generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
210
+
211
+ # Remove the prompt from the generated text
212
+ if generated.startswith(prompt):
213
+ generated = generated[len(prompt):].strip()
214
+
215
+ results.append({
216
+ "Prompt": prompt,
217
+ "Generated Text": generated,
218
+ "Reference": reference
219
+ })
220
+
221
+ return pd.DataFrame(results)
222
+
223
+ # Create trainer
224
+ ui_callback = UICallback(self)
225
+
226
+ trainer = Trainer(
227
+ model=model,
228
+ args=training_args,
229
+ train_dataset=train_dataset,
230
+ eval_dataset=val_dataset,
231
+ data_collator=data_collator,
232
+ callbacks=[ui_callback]
233
+ )
234
+
235
+ # Update training state
236
+ _TRAINING_STATE["status"] = "training"
237
+
238
+ # Start training
239
+ trainer.train()
240
+
241
+ # Save final model
242
+ trainer.save_model(os.path.join(output_dir, "final"))
243
+ tokenizer.save_pretrained(os.path.join(output_dir, "final"))
244
+
245
+ # Update training state
246
+ _TRAINING_STATE["status"] = "completed"
247
+ _TRAINING_STATE["fine_tuned_model_path"] = os.path.join(output_dir, "final")
248
+
249
+ except Exception as e:
250
+ # Update training state with error
251
+ _TRAINING_STATE["status"] = "error"
252
+ _TRAINING_STATE["error"] = str(e)
253
+ print(f"Training error: {str(e)}")
254
+
255
+ def stop(self):
256
+ """Signal the thread to stop training."""
257
+ self.stop_flag = True
258
+
259
+ def start_fine_tuning(model_name, dataset, params):
260
+ """
261
+ Start the fine-tuning process in a background thread.
262
+
263
+ Args:
264
+ model_name: Name of the model to fine-tune
265
+ dataset: Processed dataset
266
+ params: Training parameters
267
+
268
+ Returns:
269
+ TrainingThread object
270
+ """
271
+ thread = TrainingThread(model_name, dataset, params)
272
+ thread.start()
273
+ return thread
274
+
275
+ def load_training_state():
276
+ """
277
+ Get the current training state.
278
+
279
+ Returns:
280
+ Dictionary with training state information
281
+ """
282
+ return _TRAINING_STATE
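Because TrainingThread runs in the background, a caller starts it with `start_fine_tuning()` and then polls `load_training_state()` for progress, which is what the Training tab's refresh button does. A hedged sketch (the toy dataset and params are placeholders, and an actual run needs access to the gated Gemma weights):

```python
import time
from fine_tuning import start_fine_tuning, load_training_state

dataset = [{"prompt": "2 + 2 =", "completion": "4"}] * 8   # placeholder examples
params = {"learning_rate": 2e-5, "batch_size": 1, "num_epochs": 1, "max_seq_length": 128}

thread = start_fine_tuning("google/gemma-2-2b-it", dataset, params)

while thread.is_alive():
    state = load_training_state()
    if state is not None:
        print(state["status"], f'{state["current_step"]}/{state["total_steps"]}')
    time.sleep(10)

print("Final status:", load_training_state()["status"])
```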
model_utils.py ADDED
@@ -0,0 +1,179 @@
1
+ """
2
+ Utility functions for handling Gemma models
3
+ """
4
+
5
+ import os
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ from huggingface_hub import login, HfApi
9
+
10
+ def get_available_models():
11
+ """
12
+ Returns a list of available Gemma models for fine-tuning.
13
+ """
14
+ return [
15
+ "google/gemma-2-2b-it",
16
+ "google/gemma-2-9b-it",
17
+ "google/gemma-2-27b-it"
18
+ ]
19
+
20
+ def load_model(model_name, token=None):
21
+ """
22
+ Loads a model from Hugging Face Hub.
23
+
24
+ Args:
25
+ model_name: Name of the model to load
26
+ token: Hugging Face token for access to gated models
27
+
28
+ Returns:
29
+ Tuple of (model, tokenizer)
30
+ """
31
+ if token:
32
+ login(token)
33
+
34
+ # Set appropriate device
35
+ if torch.cuda.is_available():
36
+ device = "cuda"
37
+ elif torch.backends.mps.is_available():
38
+ device = "mps" # For Apple Silicon
39
+ else:
40
+ device = "cpu"
41
+
42
+ print(f"Loading model {model_name} on {device}...")
43
+
44
+ # Load model with appropriate parameters based on device and model size
45
+ model_size = model_name.split("-")[2]
46
+ if device == "cuda":
47
+ # For CUDA devices, optimize based on model size and available memory
48
+ if model_size in ["2b", "7b"]:
49
+ # Smaller models can be loaded in BF16
50
+ model = AutoModelForCausalLM.from_pretrained(
51
+ model_name,
52
+ torch_dtype=torch.bfloat16,
53
+ device_map="auto"
54
+ )
55
+ else:
56
+ # Larger models may need additional optimizations
57
+ model = AutoModelForCausalLM.from_pretrained(
58
+ model_name,
59
+ torch_dtype=torch.bfloat16,
60
+ device_map="auto",
61
+ load_in_8bit=True
62
+ )
63
+ elif device == "cpu":
64
+ # For CPU, use FP32 but load 8-bit for larger models to conserve memory
65
+ if model_size in ["2b"]:
66
+ model = AutoModelForCausalLM.from_pretrained(
67
+ model_name,
68
+ device_map={"": device}
69
+ )
70
+ else:
71
+ model = AutoModelForCausalLM.from_pretrained(
72
+ model_name,
73
+ device_map={"": device},
74
+ load_in_8bit=True
75
+ )
76
+ else: # MPS (Apple Silicon)
77
+ model = AutoModelForCausalLM.from_pretrained(
78
+ model_name,
79
+ torch_dtype=torch.float16,
80
+ device_map={"": device}
81
+ )
82
+
83
+ # Load tokenizer
84
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
85
+
86
+ return model, tokenizer
87
+
88
+ def export_model(model_path, output_dir, model_name, format="pytorch", quantization=None):
89
+ """
90
+ Exports a fine-tuned model to the specified format.
91
+
92
+ Args:
93
+ model_path: Path to the fine-tuned model
94
+ output_dir: Directory to save the exported model
95
+ model_name: Name for the exported model
96
+ format: Export format ("pytorch", "gguf", or "safetensors")
97
+ quantization: Quantization level for GGUF format
98
+
99
+ Returns:
100
+ Dictionary with export information
101
+ """
102
+ if not os.path.exists(model_path):
103
+ raise ValueError(f"Model path '{model_path}' does not exist")
104
+
105
+ os.makedirs(output_dir, exist_ok=True)
106
+ export_path = os.path.join(output_dir, model_name)
107
+ os.makedirs(export_path, exist_ok=True)
108
+
109
+ # Load the model and merge LoRA weights if applicable
110
+ model = AutoModelForCausalLM.from_pretrained(model_path)
111
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
112
+
113
+ # Handle different export formats
114
+ if format.lower() == "pytorch":
115
+ # Export as PyTorch model
116
+ model.save_pretrained(export_path)
117
+ tokenizer.save_pretrained(export_path)
118
+
119
+ elif format.lower() == "safetensors":
120
+ # Export as safetensors
121
+ model.save_pretrained(export_path, safe_serialization=True)
122
+ tokenizer.save_pretrained(export_path)
123
+
124
+ elif format.lower() == "gguf":
125
+ # For GGUF, we'd typically use a conversion script
126
+ # This is simplified; in practice you'd use specific tools for GGUF conversion
127
+ if quantization is not None and quantization.lower() != "none":
128
+ # Command for quantized GGUF conversion would go here
129
+ # In practice, use llama.cpp or similar tools
130
+ pass
131
+ else:
132
+ # Command for standard GGUF conversion would go here
133
+ pass
134
+
135
+ else:
136
+ raise ValueError(f"Unsupported export format: {format}")
137
+
138
+ # Calculate model size
139
+ model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
140
+ model_size_gb = model_size_bytes / (1024**3)
141
+
142
+ return {
143
+ "format": format.lower(),
144
+ "quantization": quantization if format.lower() == "gguf" else "None",
145
+ "model_name": model_name,
146
+ "export_path": export_path,
147
+ "model_size": f"{model_size_gb:.2f} GB"
148
+ }
149
+
150
+ def push_to_hub(model_path, repo_name, token):
151
+ """
152
+ Pushes a fine-tuned model to Hugging Face Hub.
153
+
154
+ Args:
155
+ model_path: Path to the fine-tuned model
156
+ repo_name: Name for the repository on Hugging Face Hub
157
+ token: Hugging Face token
158
+
159
+ Returns:
160
+ URL of the uploaded model
161
+ """
162
+ if not os.path.exists(model_path):
163
+ raise ValueError(f"Model path '{model_path}' does not exist")
164
+
165
+ login(token)
166
+
167
+ # Load the model and merge LoRA weights if applicable
168
+ model = AutoModelForCausalLM.from_pretrained(model_path)
169
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
170
+
171
+ # Push to hub
172
+ model.push_to_hub(repo_name)
173
+ tokenizer.push_to_hub(repo_name)
174
+
175
+ # Get the model URL
176
+ api = HfApi()
177
+ model_url = f"https://huggingface.co/{repo_name}"
178
+
179
+ return model_url
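A short sketch of how the model_utils helpers compose: list the supported checkpoints, load one, and later export or publish a fine-tuned copy (the checkpoint path, output directory, and repo name are placeholders):

```python
from model_utils import get_available_models, load_model, export_model, push_to_hub

print(get_available_models())  # ['google/gemma-2-2b-it', 'google/gemma-2-9b-it', 'google/gemma-2-27b-it']

model, tokenizer = load_model("google/gemma-2-2b-it")

info = export_model(
    model_path="outputs/20250101_120000/final",   # placeholder fine-tuned checkpoint
    output_dir="./exports",
    model_name="my-fine-tuned-gemma",
    format="safetensors",
)
print(info["export_path"], info["model_size"])

# Optionally publish to the Hub (requires a write token):
# push_to_hub("outputs/20250101_120000/final", "your-username/my-fine-tuned-gemma", token="hf_...")
```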
requirements.txt CHANGED
@@ -1 +1,10 @@
- gradio
+ gradio>=5.20.1
+ torch>=2.0.0
+ transformers>=4.36.0
+ peft>=0.5.0
+ pandas>=2.0.0
+ numpy>=1.24.0
+ matplotlib>=3.7.0
+ datasets>=2.14.0
+ accelerate>=0.20.0
+ sentencepiece>=0.1.99