burtenshaw committed on
Commit
061fdd4
Β·
1 Parent(s): df41d15

fix wandb integration

Browse files
Files changed (1) hide show
  1. app.py +53 -37
app.py CHANGED
@@ -9,7 +9,6 @@ This single Gradio app:
9
 
10
  import os
11
  import json
12
- import time
13
  import uuid
14
  import threading
15
  from datetime import datetime
@@ -18,7 +17,6 @@ import socket
18
 
19
  import gradio as gr
20
  import pandas as pd
21
- import wandb
22
  from autotrain.project import AutoTrainProject
23
  from autotrain.params import (
24
  LLMTrainingParams,
@@ -189,24 +187,19 @@ def run_training_background(run_id: str, params: Any, backend: str):
189
  save_runs(runs)
190
 
191
  try:
192
- # Initialize W&B
193
- wandb_run = wandb.init(
194
- project=WANDB_PROJECT,
195
- name=f"{params.project_name}-{int(time.time())}",
196
- tags=["autotrain", "mcp"],
197
- config={
198
- "base_model": params.model,
199
- "dataset": params.data_path,
200
- "epochs": params.epochs,
201
- "batch_size": params.batch_size,
202
- "learning_rate": params.lr,
203
- "backend": backend,
204
- },
205
- )
206
 
207
- wandb_url = (
208
- wandb_run.url if wandb_run.url else f"https://wandb.ai/{WANDB_PROJECT}"
209
- )
 
 
 
 
 
 
 
210
 
211
  # Update with W&B URL
212
  runs = load_runs()
@@ -216,14 +209,12 @@ def run_training_background(run_id: str, params: Any, backend: str):
216
  break
217
  save_runs(runs)
218
 
219
- # Create and start AutoTrain project
220
- project = AutoTrainProject(params=params, backend=backend, process=True)
221
- job_id = project.create()
222
-
223
- print(f"Training started for run {run_id} with job ID: {job_id}")
224
 
225
- # For demo purposes, simulate training completion after a short delay
226
- time.sleep(10) # In real implementation, monitor actual training
227
 
228
  # Update status to completed
229
  runs = load_runs()
@@ -231,13 +222,16 @@ def run_training_background(run_id: str, params: Any, backend: str):
231
  if run["run_id"] == run_id:
232
  run["status"] = "completed"
233
  run["completed_at"] = datetime.utcnow().isoformat()
 
 
234
  break
235
  save_runs(runs)
236
 
237
- wandb.finish()
238
-
239
  except Exception as e:
240
  print(f"Training failed for run {run_id}: {str(e)}")
 
 
 
241
 
242
  # Update status to failed
243
  runs = load_runs()
@@ -249,9 +243,6 @@ def run_training_background(run_id: str, params: Any, backend: str):
249
  break
250
  save_runs(runs)
251
 
252
- if wandb.run:
253
- wandb.finish()
254
-
255
 
256
  # MCP Tool Functions (these automatically become MCP tools)
257
  def start_training_job(
@@ -633,13 +624,19 @@ def get_system_status(random_string: str = "") -> str:
633
  }
634
 
635
  πŸ’‘ **Access Points:**
636
- β€’ Gradio UI: http://localhost:7860
637
- β€’ MCP Server: http://localhost:7860/gradio_api/mcp/sse
638
- β€’ MCP Schema: http://localhost:7860/gradio_api/mcp/schema
639
 
640
  πŸ› οΈ **W&B Integration:**
641
  β€’ Project: {WANDB_PROJECT}
642
- β€’ Set WANDB_PROJECT environment variable to customize"""
 
 
 
 
 
 
643
 
644
  return status_text
645
 
@@ -864,8 +861,8 @@ with gr.Blocks(
864
 
865
  This Gradio app automatically serves as an MCP server.
866
 
867
- **MCP Endpoint:** `http://localhost:7860/gradio_api/mcp/sse`
868
- **MCP Schema:** `http://localhost:7860/gradio_api/mcp/schema`
869
 
870
  ### Available MCP Tools:
871
 
@@ -875,6 +872,24 @@ with gr.Blocks(
875
  - `get_task_recommendations` - Get training recommendations
876
  - `get_system_status` - Check system status
877
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
878
  ### πŸ€— Hugging Face Hub Integration:
879
 
880
  To push models to the Hub, set these environment variables:
@@ -906,6 +921,7 @@ with gr.Blocks(
906
 
907
  Total Runs: {len(load_runs())}
908
  W&B Project: {WANDB_PROJECT}
 
909
  Hub Auth: {"βœ… Configured" if os.environ.get("HF_TOKEN") else "❌ Missing HF_TOKEN"}
910
  """)
911
 
 
9
 
10
  import os
11
  import json
 
12
  import uuid
13
  import threading
14
  from datetime import datetime
 
17
 
18
  import gradio as gr
19
  import pandas as pd
 
20
  from autotrain.project import AutoTrainProject
21
  from autotrain.params import (
22
  LLMTrainingParams,
 
187
  save_runs(runs)
188
 
189
  try:
190
+ # Set W&B environment variables for AutoTrain to use
191
+ os.environ["WANDB_PROJECT"] = WANDB_PROJECT
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
+ print(f"Starting real training for run {run_id}")
194
+ print(f"Model: {params.model}")
195
+ print(f"Dataset: {params.data_path}")
196
+ print(f"Backend: {backend}")
197
+
198
+ # Create AutoTrain project - this will handle W&B internally
199
+ project = AutoTrainProject(params=params, backend=backend, process=True)
200
+
201
+ # Generate approximate W&B URL
202
+ wandb_url = f"https://wandb.ai/{WANDB_PROJECT}"
203
 
204
  # Update with W&B URL
205
  runs = load_runs()
 
209
  break
210
  save_runs(runs)
211
 
212
+ # Actually run the training - this blocks until completion
213
+ print(f"Executing training job for run {run_id}...")
214
+ result = project.create()
 
 
215
 
216
+ print(f"Training completed successfully for run {run_id}")
217
+ print(f"Result: {result}")
218
 
219
  # Update status to completed
220
  runs = load_runs()
 
222
  if run["run_id"] == run_id:
223
  run["status"] = "completed"
224
  run["completed_at"] = datetime.utcnow().isoformat()
225
+ if result:
226
+ run["result"] = str(result)
227
  break
228
  save_runs(runs)
229
 
 
 
230
  except Exception as e:
231
  print(f"Training failed for run {run_id}: {str(e)}")
232
+ import traceback
233
+
234
+ traceback.print_exc()
235
 
236
  # Update status to failed
237
  runs = load_runs()
 
243
  break
244
  save_runs(runs)
245
 
 
 
 
246
 
247
  # MCP Tool Functions (these automatically become MCP tools)
248
  def start_training_job(
 
624
  }
625
 
626
  πŸ’‘ **Access Points:**
627
+ β€’ Gradio UI: http://SPACE_URL
628
+ β€’ MCP Server: http://SPACE_URL/gradio_api/mcp/sse
629
+ β€’ MCP Schema: http://SPACE_URL/gradio_api/mcp/schema
630
 
631
  πŸ› οΈ **W&B Integration:**
632
  β€’ Project: {WANDB_PROJECT}
633
+ β€’ API Key: {"βœ… Configured" if os.environ.get("WANDB_API_KEY") else "❌ Missing"}
634
+ β€’ Training Metrics: {
635
+ "βœ… Enabled"
636
+ if os.environ.get("WANDB_API_KEY")
637
+ else "❌ System metrics only"
638
+ }
639
+ β€’ Set WANDB_API_KEY for complete training metrics logging"""
640
 
641
  return status_text
642
 
 
861
 
862
  This Gradio app automatically serves as an MCP server.
863
 
864
+ **MCP Endpoint:** `http://SPACE_URL/gradio_api/mcp/sse`
865
+ **MCP Schema:** `http://SPACE_URL/gradio_api/mcp/schema`
866
 
867
  ### Available MCP Tools:
868
 
 
872
  - `get_task_recommendations` - Get training recommendations
873
  - `get_system_status` - Check system status
874
 
875
+ ### πŸ“Š Weights & Biases Integration:
876
+
877
+ For **complete training metrics** (loss, accuracy, etc.), set:
878
+
879
+ ```bash
880
+ export WANDB_API_KEY="your-wandb-api-key"
881
+ export WANDB_PROJECT="autotrain-mcp" # Optional: custom project name
882
+ ```
883
+
884
+ Get your API key from: https://wandb.ai/authorize
885
+
886
+ **What gets logged by AutoTrain:**
887
+ - βœ… Training/validation loss
888
+ - βœ… Learning rate schedule
889
+ - βœ… Gradient norms
890
+ - βœ… Model checkpoints
891
+ - βœ… System metrics (GPU, CPU, memory)
892
+
893
  ### πŸ€— Hugging Face Hub Integration:
894
 
895
  To push models to the Hub, set these environment variables:
 
921
 
922
  Total Runs: {len(load_runs())}
923
  W&B Project: {WANDB_PROJECT}
924
+ W&B Auth: {"βœ… Configured" if os.environ.get("WANDB_API_KEY") else "❌ Missing WANDB_API_KEY"}
925
  Hub Auth: {"βœ… Configured" if os.environ.get("HF_TOKEN") else "❌ Missing HF_TOKEN"}
926
  """)
927