Spaces:
Runtime error
Runtime error
burtenshaw
committed on
Commit
·
061fdd4
1
Parent(s):
df41d15
fix wandb integration
Browse files
app.py
CHANGED
|
@@ -9,7 +9,6 @@ This single Gradio app:
|
|
| 9 |
|
| 10 |
import os
|
| 11 |
import json
|
| 12 |
-
import time
|
| 13 |
import uuid
|
| 14 |
import threading
|
| 15 |
from datetime import datetime
|
|
@@ -18,7 +17,6 @@ import socket
|
|
| 18 |
|
| 19 |
import gradio as gr
|
| 20 |
import pandas as pd
|
| 21 |
-
import wandb
|
| 22 |
from autotrain.project import AutoTrainProject
|
| 23 |
from autotrain.params import (
|
| 24 |
LLMTrainingParams,
|
|
@@ -189,24 +187,19 @@ def run_training_background(run_id: str, params: Any, backend: str):
|
|
| 189 |
save_runs(runs)
|
| 190 |
|
| 191 |
try:
|
| 192 |
-
#
|
| 193 |
-
|
| 194 |
-
project=WANDB_PROJECT,
|
| 195 |
-
name=f"{params.project_name}-{int(time.time())}",
|
| 196 |
-
tags=["autotrain", "mcp"],
|
| 197 |
-
config={
|
| 198 |
-
"base_model": params.model,
|
| 199 |
-
"dataset": params.data_path,
|
| 200 |
-
"epochs": params.epochs,
|
| 201 |
-
"batch_size": params.batch_size,
|
| 202 |
-
"learning_rate": params.lr,
|
| 203 |
-
"backend": backend,
|
| 204 |
-
},
|
| 205 |
-
)
|
| 206 |
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
# Update with W&B URL
|
| 212 |
runs = load_runs()
|
|
@@ -216,14 +209,12 @@ def run_training_background(run_id: str, params: Any, backend: str):
|
|
| 216 |
break
|
| 217 |
save_runs(runs)
|
| 218 |
|
| 219 |
-
#
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
print(f"Training started for run {run_id} with job ID: {job_id}")
|
| 224 |
|
| 225 |
-
|
| 226 |
-
|
| 227 |
|
| 228 |
# Update status to completed
|
| 229 |
runs = load_runs()
|
|
@@ -231,13 +222,16 @@ def run_training_background(run_id: str, params: Any, backend: str):
|
|
| 231 |
if run["run_id"] == run_id:
|
| 232 |
run["status"] = "completed"
|
| 233 |
run["completed_at"] = datetime.utcnow().isoformat()
|
|
|
|
|
|
|
| 234 |
break
|
| 235 |
save_runs(runs)
|
| 236 |
|
| 237 |
-
wandb.finish()
|
| 238 |
-
|
| 239 |
except Exception as e:
|
| 240 |
print(f"Training failed for run {run_id}: {str(e)}")
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
# Update status to failed
|
| 243 |
runs = load_runs()
|
|
@@ -249,9 +243,6 @@ def run_training_background(run_id: str, params: Any, backend: str):
|
|
| 249 |
break
|
| 250 |
save_runs(runs)
|
| 251 |
|
| 252 |
-
if wandb.run:
|
| 253 |
-
wandb.finish()
|
| 254 |
-
|
| 255 |
|
| 256 |
# MCP Tool Functions (these automatically become MCP tools)
|
| 257 |
def start_training_job(
|
|
@@ -633,13 +624,19 @@ def get_system_status(random_string: str = "") -> str:
|
|
| 633 |
}
|
| 634 |
|
| 635 |
📡 **Access Points:**
|
| 636 |
-
• Gradio UI: http://
|
| 637 |
-
• MCP Server: http://
|
| 638 |
-
• MCP Schema: http://
|
| 639 |
|
| 640 |
🛠️ **W&B Integration:**
|
| 641 |
• Project: {WANDB_PROJECT}
|
| 642 |
-
•
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
|
| 644 |
return status_text
|
| 645 |
|
|
@@ -864,8 +861,8 @@ with gr.Blocks(
|
|
| 864 |
|
| 865 |
This Gradio app automatically serves as an MCP server.
|
| 866 |
|
| 867 |
-
**MCP Endpoint:** `http://
|
| 868 |
-
**MCP Schema:** `http://
|
| 869 |
|
| 870 |
### Available MCP Tools:
|
| 871 |
|
|
@@ -875,6 +872,24 @@ with gr.Blocks(
|
|
| 875 |
- `get_task_recommendations` - Get training recommendations
|
| 876 |
- `get_system_status` - Check system status
|
| 877 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 878 |
### 🤗 Hugging Face Hub Integration:
|
| 879 |
|
| 880 |
To push models to the Hub, set these environment variables:
|
|
@@ -906,6 +921,7 @@ with gr.Blocks(
|
|
| 906 |
|
| 907 |
Total Runs: {len(load_runs())}
|
| 908 |
W&B Project: {WANDB_PROJECT}
|
|
|
|
| 909 |
Hub Auth: {"✅ Configured" if os.environ.get("HF_TOKEN") else "❌ Missing HF_TOKEN"}
|
| 910 |
""")
|
| 911 |
|
|
|
|
| 9 |
|
| 10 |
import os
|
| 11 |
import json
|
|
|
|
| 12 |
import uuid
|
| 13 |
import threading
|
| 14 |
from datetime import datetime
|
|
|
|
| 17 |
|
| 18 |
import gradio as gr
|
| 19 |
import pandas as pd
|
|
|
|
| 20 |
from autotrain.project import AutoTrainProject
|
| 21 |
from autotrain.params import (
|
| 22 |
LLMTrainingParams,
|
|
|
|
| 187 |
save_runs(runs)
|
| 188 |
|
| 189 |
try:
|
| 190 |
+
# Set W&B environment variables for AutoTrain to use
|
| 191 |
+
os.environ["WANDB_PROJECT"] = WANDB_PROJECT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
+
print(f"Starting real training for run {run_id}")
|
| 194 |
+
print(f"Model: {params.model}")
|
| 195 |
+
print(f"Dataset: {params.data_path}")
|
| 196 |
+
print(f"Backend: {backend}")
|
| 197 |
+
|
| 198 |
+
# Create AutoTrain project - this will handle W&B internally
|
| 199 |
+
project = AutoTrainProject(params=params, backend=backend, process=True)
|
| 200 |
+
|
| 201 |
+
# Generate approximate W&B URL
|
| 202 |
+
wandb_url = f"https://wandb.ai/{WANDB_PROJECT}"
|
| 203 |
|
| 204 |
# Update with W&B URL
|
| 205 |
runs = load_runs()
|
|
|
|
| 209 |
break
|
| 210 |
save_runs(runs)
|
| 211 |
|
| 212 |
+
# Actually run the training - this blocks until completion
|
| 213 |
+
print(f"Executing training job for run {run_id}...")
|
| 214 |
+
result = project.create()
|
|
|
|
|
|
|
| 215 |
|
| 216 |
+
print(f"Training completed successfully for run {run_id}")
|
| 217 |
+
print(f"Result: {result}")
|
| 218 |
|
| 219 |
# Update status to completed
|
| 220 |
runs = load_runs()
|
|
|
|
| 222 |
if run["run_id"] == run_id:
|
| 223 |
run["status"] = "completed"
|
| 224 |
run["completed_at"] = datetime.utcnow().isoformat()
|
| 225 |
+
if result:
|
| 226 |
+
run["result"] = str(result)
|
| 227 |
break
|
| 228 |
save_runs(runs)
|
| 229 |
|
|
|
|
|
|
|
| 230 |
except Exception as e:
|
| 231 |
print(f"Training failed for run {run_id}: {str(e)}")
|
| 232 |
+
import traceback
|
| 233 |
+
|
| 234 |
+
traceback.print_exc()
|
| 235 |
|
| 236 |
# Update status to failed
|
| 237 |
runs = load_runs()
|
|
|
|
| 243 |
break
|
| 244 |
save_runs(runs)
|
| 245 |
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
# MCP Tool Functions (these automatically become MCP tools)
|
| 248 |
def start_training_job(
|
|
|
|
| 624 |
}
|
| 625 |
|
| 626 |
📡 **Access Points:**
|
| 627 |
+
• Gradio UI: http://SPACE_URL
|
| 628 |
+
• MCP Server: http://SPACE_URL/gradio_api/mcp/sse
|
| 629 |
+
• MCP Schema: http://SPACE_URL/gradio_api/mcp/schema
|
| 630 |
|
| 631 |
🛠️ **W&B Integration:**
|
| 632 |
• Project: {WANDB_PROJECT}
|
| 633 |
+
• API Key: {"✅ Configured" if os.environ.get("WANDB_API_KEY") else "❌ Missing"}
|
| 634 |
+
• Training Metrics: {
|
| 635 |
+
"✅ Enabled"
|
| 636 |
+
if os.environ.get("WANDB_API_KEY")
|
| 637 |
+
else "❌ System metrics only"
|
| 638 |
+
}
|
| 639 |
+
• Set WANDB_API_KEY for complete training metrics logging"""
|
| 640 |
|
| 641 |
return status_text
|
| 642 |
|
|
|
|
| 861 |
|
| 862 |
This Gradio app automatically serves as an MCP server.
|
| 863 |
|
| 864 |
+
**MCP Endpoint:** `http://SPACE_URL/gradio_api/mcp/sse`
|
| 865 |
+
**MCP Schema:** `http://SPACE_URL/gradio_api/mcp/schema`
|
| 866 |
|
| 867 |
### Available MCP Tools:
|
| 868 |
|
|
|
|
| 872 |
- `get_task_recommendations` - Get training recommendations
|
| 873 |
- `get_system_status` - Check system status
|
| 874 |
|
| 875 |
+
### 📊 Weights & Biases Integration:
|
| 876 |
+
|
| 877 |
+
For **complete training metrics** (loss, accuracy, etc.), set:
|
| 878 |
+
|
| 879 |
+
```bash
|
| 880 |
+
export WANDB_API_KEY="your-wandb-api-key"
|
| 881 |
+
export WANDB_PROJECT="autotrain-mcp" # Optional: custom project name
|
| 882 |
+
```
|
| 883 |
+
|
| 884 |
+
Get your API key from: https://wandb.ai/authorize
|
| 885 |
+
|
| 886 |
+
**What gets logged by AutoTrain:**
|
| 887 |
+
- ✅ Training/validation loss
|
| 888 |
+
- ✅ Learning rate schedule
|
| 889 |
+
- ✅ Gradient norms
|
| 890 |
+
- ✅ Model checkpoints
|
| 891 |
+
- ✅ System metrics (GPU, CPU, memory)
|
| 892 |
+
|
| 893 |
### 🤗 Hugging Face Hub Integration:
|
| 894 |
|
| 895 |
To push models to the Hub, set these environment variables:
|
|
|
|
| 921 |
|
| 922 |
Total Runs: {len(load_runs())}
|
| 923 |
W&B Project: {WANDB_PROJECT}
|
| 924 |
+
W&B Auth: {"✅ Configured" if os.environ.get("WANDB_API_KEY") else "❌ Missing WANDB_API_KEY"}
|
| 925 |
Hub Auth: {"✅ Configured" if os.environ.get("HF_TOKEN") else "❌ Missing HF_TOKEN"}
|
| 926 |
""")
|
| 927 |
|