# Spaces: Runtime error
# (The line above appears to be a status banner from the hosting page that was
# captured together with this file; it is not program code.)
"""
AutoTrain Gradio MCP Server - All-in-One

This single Gradio app:
1. Provides a web interface for managing AutoTrain jobs
2. Automatically exposes MCP tools at /gradio_api/mcp/sse
3. Handles all AutoTrain operations directly (no FastAPI needed)
"""
import json
import os
import socket
import threading
import traceback
import uuid
from datetime import datetime
from typing import Any, Dict, List

import gradio as gr
import pandas as pd
import wandb
from autotrain.params import (
    ImageClassificationParams,
    LLMTrainingParams,
    TextClassificationParams,
)
from autotrain.project import AutoTrainProject
# Run history lives in a flat JSON file (replace with SQLite if needed).
RUNS_FILE = "training_runs.json"
# W&B project name; the WANDB_PROJECT environment variable overrides the default.
WANDB_PROJECT = os.getenv("WANDB_PROJECT", "autotrain-mcp")
def load_runs() -> List[Dict[str, Any]]:
    """Load the persisted training runs from ``RUNS_FILE``.

    Returns:
        The stored list of run records. An empty list is returned when the
        file is missing, unreadable, not valid JSON, or holds a JSON value
        that is not an array, so callers can always iterate over and append
        to the result.
    """
    try:
        with open(RUNS_FILE, "r", encoding="utf-8") as f:
            runs = json.load(f)
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file; ValueError covers
        # json.JSONDecodeError and undecodable bytes.
        return []
    # A hand-edited file could hold e.g. an object or null; callers need a list.
    return runs if isinstance(runs, list) else []
def save_runs(runs: List[Dict[str, Any]]) -> None:
    """Persist ``runs`` to ``RUNS_FILE`` as indented JSON.

    The data is first written to a temporary sibling file and then moved over
    ``RUNS_FILE`` with ``os.replace``. A reader never sees a half-written
    file, and a failed write leaves the previous contents untouched.

    Args:
        runs: Complete list of run records to store.

    Raises:
        OSError: If the file cannot be written or replaced.
        TypeError: If ``runs`` contains values that are not JSON-serializable.
    """
    # Unique per process and thread so concurrent writers never share a temp file.
    tmp_path = f"{RUNS_FILE}.{os.getpid()}.{threading.get_ident()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(runs, f, indent=2)
        os.replace(tmp_path, RUNS_FILE)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def get_status_emoji(status: str) -> str:
    """Return the display glyph for a run status (matched case-insensitively).

    Statuses without a dedicated glyph get the fallback glyph.
    """
    # NOTE(review): these glyph literals look mis-decoded (mojibake) in this
    # copy of the file - confirm against the original source.
    glyphs = dict(
        (
            ("pending", "β³"),
            ("running", "π"),
            ("completed", "β "),
            ("failed", "β"),
            ("cancelled", "βΉοΈ"),
        )
    )
    key = status.lower()
    if key in glyphs:
        return glyphs[key]
    return "β"
def create_autotrain_params(
    task: str,
    base_model: str,
    project_name: str,
    dataset_path: str,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    push_to_hub: bool,
    hub_repo_id: str = "",
    **kwargs,
):
    """Create the AutoTrain parameter object for the requested task type.

    Args:
        task: One of ``text-classification``, ``image-classification``,
            ``llm-sft``, ``llm-dpo``, ``llm-orpo`` or ``llm-reward``.
        base_model: Model identifier, passed to AutoTrain as ``model``.
        project_name: AutoTrain project name.
        dataset_path: Dataset path or name, passed as ``data_path``.
        epochs: Number of training epochs.
        batch_size: Training batch size.
        learning_rate: Learning rate, passed as ``lr``.
        push_to_hub: When true, hub settings (``HF_USERNAME`` / ``HF_TOKEN``
            read from the environment) are added to the parameters.
        hub_repo_id: Optional repository id; forwarded only when
            ``push_to_hub`` is true and the value is non-empty.
        **kwargs: Optional overrides read with defaults: ``train_split``,
            ``valid_split``, ``text_column``, ``target_column``,
            ``max_seq_length``, ``block_size``, ``use_peft``,
            ``quantization`` and ``image_column``.

    Returns:
        A ``TextClassificationParams``, ``LLMTrainingParams`` or
        ``ImageClassificationParams`` instance.

    Raises:
        ValueError: If ``task`` is not supported, including ``llm-*``
            variants that have no trainer mapping.
    """
    llm_trainers = {
        "llm-sft": "sft",
        "llm-dpo": "dpo",
        "llm-orpo": "orpo",
        "llm-reward": "reward",
    }
    # Validate up front so every unsupported task, including an unknown
    # "llm-*" name, fails with the same ValueError instead of a bare KeyError.
    supported_tasks = {"text-classification", "image-classification", *llm_trainers}
    if task not in supported_tasks:
        raise ValueError(f"Unsupported task type: {task}")

    # Hub configuration
    hub_config = {}
    if push_to_hub:
        hub_config = {
            "push_to_hub": True,
            "username": os.environ.get("HF_USERNAME", ""),
            "token": os.environ.get("HF_TOKEN", ""),
        }
        # Only a custom repo id is forwarded; nothing is set when it is empty.
        if hub_repo_id:
            hub_config["repo_id"] = hub_repo_id

    common_params = {
        "model": base_model,
        "project_name": project_name,
        "data_path": dataset_path,
        "train_split": kwargs.get("train_split", "train"),
        "valid_split": kwargs.get("valid_split"),
        "epochs": epochs,
        "batch_size": batch_size,
        "lr": learning_rate,
        "log": "wandb",
        # Required defaults
        "warmup_ratio": 0.1,
        "gradient_accumulation": 1,
        "optimizer": "adamw_torch",
        "scheduler": "linear",
        "weight_decay": 0.01,
        "max_grad_norm": 1.0,
        "seed": 42,
        "logging_steps": 10,
        "auto_find_batch_size": False,
        "mixed_precision": "no",
        "save_total_limit": 1,
        "eval_strategy": "epoch",
        **hub_config,  # Add hub configuration
    }

    if task == "text-classification":
        return TextClassificationParams(
            **common_params,
            text_column=kwargs.get("text_column", "text"),
            target_column=kwargs.get("target_column", "label"),
            max_seq_length=kwargs.get("max_seq_length", 128),
            early_stopping_patience=3,
            early_stopping_threshold=0.01,
        )

    if task in llm_trainers:
        return LLMTrainingParams(
            **common_params,
            text_column=kwargs.get("text_column", "messages"),
            block_size=kwargs.get("block_size", 2048),
            peft=kwargs.get("use_peft", True),
            quantization=kwargs.get("quantization", "int4"),
            trainer=llm_trainers[task],
            chat_template="tokenizer",
            # LLM-specific defaults
            add_eos_token=True,
            model_max_length=2048,
            padding="right",
            use_flash_attention_2=False,
            disable_gradient_checkpointing=False,
            target_modules="all-linear",
            merge_adapter=False,
            lora_r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            model_ref=None,
            dpo_beta=0.1,
            max_prompt_length=512,
            max_completion_length=1024,
            prompt_text_column="prompt",
            rejected_text_column="rejected",
            unsloth=False,
            distributed_backend="accelerate",
        )

    # Only "image-classification" remains after the validation above.
    return ImageClassificationParams(
        **common_params,
        image_column=kwargs.get("image_column", "image"),
        target_column=kwargs.get("target_column", "label"),
    )
def _update_run(run_id: str, **fields: Any) -> None:
    """Merge ``fields`` into the stored record for ``run_id`` and save all runs."""
    runs = load_runs()
    for run in runs:
        if run.get("run_id") == run_id:
            run.update(fields)
            # Keep the record's modification timestamp in step with its state.
            run["updated_at"] = datetime.utcnow().isoformat()
            break
    save_runs(runs)


def run_training_background(run_id: str, params: Any, backend: str):
    """Run one training job and record its lifecycle in the run store.

    Intended as a thread target. The stored run is marked ``running``, then
    ``completed`` (with the W&B URL and, if truthy, the stringified result)
    or ``failed`` (with the error message).

    Args:
        run_id: ID of the run record to update.
        params: AutoTrain parameter object; its ``model`` and ``data_path``
            attributes are logged.
        backend: Backend name passed to ``AutoTrainProject``.
    """
    _update_run(
        run_id,
        status="running",
        started_at=datetime.utcnow().isoformat(),
    )

    try:
        # Set W&B environment variables for AutoTrain to use
        os.environ["WANDB_PROJECT"] = WANDB_PROJECT

        print(f"Starting real training for run {run_id}")
        print(f"Model: {params.model}")
        print(f"Dataset: {params.data_path}")
        print(f"Backend: {backend}")

        project = AutoTrainProject(params=params, backend=backend, process=True)

        # NOTE(review): the original comment states that create() blocks until
        # training completes - confirm this for every backend in use.
        print(f"Executing training job for run {run_id}...")
        result = project.create()
        print(f"Training completed successfully for run {run_id}")
        print(f"Result: {result}")

        # Fall back to the project page when no W&B run is active here.
        wandb_url = f"https://wandb.ai/{WANDB_PROJECT}"
        try:
            if wandb.run is not None:
                wandb_url = wandb.run.url
                print(f"Got actual W&B URL: {wandb_url}")
            else:
                print("No active W&B run found, using default URL")
        except Exception as e:
            print(f"Could not get W&B URL: {e}")

        # One write records both the W&B URL and the completion state.
        completion = {
            "wandb_url": wandb_url,
            "status": "completed",
            "completed_at": datetime.utcnow().isoformat(),
        }
        if result:
            completion["result"] = str(result)
        _update_run(run_id, **completion)
    except Exception as e:
        print(f"Training failed for run {run_id}: {str(e)}")
        traceback.print_exc()
        _update_run(
            run_id,
            status="failed",
            error_message=str(e),
            completed_at=datetime.utcnow().isoformat(),
        )
# MCP Tool Functions (these automatically become MCP tools)
def start_training_job(
    task: str = "text-classification",
    project_name: str = "test-project",
    base_model: str = "distilbert-base-uncased",
    dataset_path: str = "imdb",
    epochs: str = "1",
    batch_size: str = "8",
    learning_rate: str = "2e-5",
    backend: str = "local",
    push_to_hub: str = "false",
    hub_repo_id: str = "",
) -> str:
    """
    Start a new AutoTrain training job.

    Args:
        task: Type of training task (text-classification, llm-sft,
            llm-dpo, llm-orpo, image-classification)
        project_name: Name for the training project
        base_model: Base model from Hugging Face Hub
            (e.g., distilbert-base-uncased)
        dataset_path: Dataset path or HF dataset name (e.g., imdb)
        epochs: Number of training epochs (default: 1)
        batch_size: Training batch size (default: 8)
        learning_rate: Learning rate for training (default: 2e-5)
        backend: Training backend to use (default: local)
        push_to_hub: Whether to push final model to Hub (true/false)
        hub_repo_id: Custom repository ID for Hub (optional)

    Returns:
        Status message with run ID and details
    """
    try:
        # Convert string parameters
        epochs_int = int(epochs)
        batch_size_int = int(batch_size)
        learning_rate_float = float(learning_rate)
        # str() keeps this working when a caller passes a real bool.
        push_to_hub_bool = str(push_to_hub).strip().lower() == "true"

        # Build the parameters BEFORE recording the run, so an invalid request
        # does not leave a record stuck in "pending" forever.
        params = create_autotrain_params(
            task=task,
            base_model=base_model,
            project_name=project_name,
            dataset_path=dataset_path,
            epochs=epochs_int,
            batch_size=batch_size_int,
            learning_rate=learning_rate_float,
            push_to_hub=push_to_hub_bool,
            hub_repo_id=hub_repo_id,
        )

        # Generate run ID
        run_id = str(uuid.uuid4())

        # Create run record
        run_data = {
            "run_id": run_id,
            "project_name": project_name,
            "task": task,
            "base_model": base_model,
            "dataset_path": dataset_path,
            "status": "pending",
            "created_at": datetime.utcnow().isoformat(),
            "updated_at": datetime.utcnow().isoformat(),
            "push_to_hub": push_to_hub_bool,
            "hub_repo_id": hub_repo_id,
            "config": {
                "task": task,
                "epochs": epochs_int,
                "batch_size": batch_size_int,
                "learning_rate": learning_rate_float,
                "backend": backend,
                "push_to_hub": push_to_hub_bool,
                "hub_repo_id": hub_repo_id,
            },
        }

        # Save to storage
        runs = load_runs()
        runs.append(run_data)
        save_runs(runs)

        # Start training in background; daemon so it never blocks shutdown.
        thread = threading.Thread(
            target=run_training_background,
            args=(run_id, params, backend),
            daemon=True,
        )
        thread.start()

        # Build result message
        lines = [
            "β Training job submitted successfully!",
            f"Run ID: {run_id}",
            f"Project: {project_name}",
            f"Task: {task}",
            f"Model: {base_model}",
            f"Dataset: {dataset_path}",
            "Configuration:",
            f"β’ Epochs: {epochs}",
            f"β’ Batch Size: {batch_size}",
            f"β’ Learning Rate: {learning_rate}",
            f"β’ Backend: {backend}",
        ]
        if push_to_hub_bool:
            final_repo = hub_repo_id if hub_repo_id else project_name
            lines += [
                "β’ Push to Hub: β Enabled",
                f"β’ Repository: {final_repo}",
                "β’ Requires: HF_USERNAME and HF_TOKEN environment variables",
            ]
        else:
            lines.append("β’ Push to Hub: β Disabled")
        lines += [
            "π Monitor progress:",
            "β’ Gradio UI: http://localhost:7860",
            "β’ W&B tracking will be available once training starts",
            "π‘ Use get_training_runs() to check status",
        ]
        return "\n".join(lines)
    except Exception as e:
        # Tool boundary: report the problem as text instead of raising.
        return f"β Error submitting job: {str(e)}"
def get_training_runs(limit: str = "20", status: str = "") -> str:
    """
    Get list of training runs with their status and details.

    Args:
        limit: Maximum number of runs to return (default: 20); must be a
            positive integer
        status: Filter by run status (pending, running, completed,
            failed, cancelled); matched case-insensitively

    Returns:
        Formatted list of training runs with status and links
    """
    try:
        max_runs = int(limit)
        # runs[-0:] is the WHOLE list and a negative limit slices from the
        # wrong end, so reject anything that is not a positive count.
        if max_runs < 1:
            raise ValueError("limit must be a positive integer")

        runs = load_runs()

        # Filter by status if provided
        wanted = str(status).strip().lower()
        if wanted:
            runs = [
                run for run in runs if str(run.get("status", "")).lower() == wanted
            ]

        # Keep only the most recent entries (the store is append-ordered).
        runs = runs[-max_runs:]
        if not runs:
            return "No training runs found. Start a new training job to see it here!"

        parts = [f"π Training Runs (showing {len(runs)}):\n\n"]
        for run in reversed(runs):  # Show newest first
            status_emoji = get_status_emoji(run["status"])
            parts.append(
                f"{status_emoji} **{run['project_name']}** ({run['run_id'][:8]}...)\n"
            )
            parts.append(f" Task: {run['task']}\n")
            parts.append(f" Model: {run['base_model']}\n")
            parts.append(f" Status: {run['status'].title()}\n")
            parts.append(f" Created: {run['created_at']}\n")
            if run.get("wandb_url"):
                parts.append(f" π W&B: {run['wandb_url']}\n")
            if run.get("error_message"):
                parts.append(f" β Error: {run['error_message']}\n")
            parts.append("\n")
        return "".join(parts)
    except Exception as e:
        # Tool boundary: report the problem as text instead of raising.
        return f"β Error fetching runs: {str(e)}"
def get_run_details(run_id: str) -> str:
    """
    Get detailed information about a specific training run.

    Args:
        run_id: ID of the training run (can be partial ID; a trailing
            "..." as shown in run listings is ignored)

    Returns:
        Detailed run information including config and status
    """
    try:
        # Listings show IDs as "abcd1234..."; accept that form when pasted.
        query = str(run_id or "").strip().rstrip(".")

        # An empty query would prefix-match every ID (startswith("") is
        # always true), so it is treated as "not found".
        run = None
        if query:
            run = next(
                (r for r in load_runs() if r["run_id"].startswith(query)),
                None,
            )
        if not run:
            return f"β Training run {run_id} not found"

        lines = [
            "π Training Run Details",
            f"**Run ID:** {run['run_id']}",
            f"**Project:** {run['project_name']}",
            f"**Task:** {run['task']}",
            f"**Model:** {run['base_model']}",
            f"**Dataset:** {run['dataset_path']}",
            f"**Status:** {run['status'].title()}",
            "**Timestamps:**",
            f"β’ Created: {run['created_at']}",
            f"β’ Updated: {run.get('updated_at', 'N/A')}",
        ]
        if run.get("started_at"):
            lines.append(f"β’ Started: {run['started_at']}")
        if run.get("completed_at"):
            lines.append(f"β’ Completed: {run['completed_at']}")
        if run.get("wandb_url"):
            lines += ["", f"π **W&B Dashboard:** {run['wandb_url']}"]
        if run.get("error_message"):
            lines += ["", f"β **Error:** {run['error_message']}"]

        config = run.get("config")
        if config:
            lines += [
                "",
                "βοΈ **Training Configuration:**",
                f"β’ Epochs: {config.get('epochs')}",
                f"β’ Batch Size: {config.get('batch_size')}",
                f"β’ Learning Rate: {config.get('learning_rate')}",
                f"β’ Backend: {config.get('backend')}",
            ]
            # Hub configuration
            if config.get("push_to_hub"):
                lines.append("β’ Push to Hub: β Enabled")
                if config.get("hub_repo_id"):
                    lines.append(f"β’ Hub Repository: {config.get('hub_repo_id')}")
                else:
                    lines.append(
                        f"β’ Hub Repository: {run['project_name']} (default)"
                    )
            else:
                lines.append("β’ Push to Hub: β Disabled")
        return "\n".join(lines)
    except Exception as e:
        # Tool boundary: report the problem as text instead of raising.
        return f"β Error fetching run details: {str(e)}"
def get_task_recommendations(
    task: str = "text-classification", dataset_size: str = "medium"
) -> str:
    """
    Get training recommendations for a specific task type.

    Args:
        task: Task type (text-classification, llm-sft, image-classification)
        dataset_size: Size of dataset (small, medium, large); only echoed
            in the heading

    Returns:
        Recommended models, parameters, and best practices
    """
    catalog = {
        "text-classification": {
            "models": ["distilbert-base-uncased", "bert-base-uncased", "roberta-base"],
            "params": {"batch_size": 16, "learning_rate": 2e-5, "epochs": 3},
            "backends": ["local", "spaces-t4-small"],
            "notes": [
                "Good for sentiment analysis",
                "Works well with IMDB, AG News datasets",
            ],
        },
        "llm-sft": {
            "models": [
                "microsoft/DialoGPT-medium",
                "HuggingFaceTB/SmolLM2-1.7B-Instruct",
            ],
            "params": {"batch_size": 1, "learning_rate": 1e-5, "epochs": 3},
            "backends": ["spaces-t4-medium", "spaces-a10g-large"],
            "notes": ["Use PEFT for efficiency", "Ensure proper chat formatting"],
        },
        "image-classification": {
            "models": ["google/vit-base-patch16-224", "microsoft/resnet-50"],
            "params": {"batch_size": 32, "learning_rate": 2e-5, "epochs": 5},
            "backends": ["local", "spaces-t4-small"],
            "notes": ["Ensure images are preprocessed", "Works with CIFAR, ImageNet"],
        },
    }

    if task in catalog:
        entry = catalog[task]
    else:
        # Tasks without an entry still get a generic, mostly empty section set.
        entry = {
            "models": [],
            "params": {},
            "backends": ["local"],
            "notes": ["No specific recommendations available"],
        }

    def bullets(items):
        """Render each item as its own bullet line."""
        return "\n".join(f"β’ {item}" for item in items)

    param_items = [f"{key}: {value}" for key, value in entry["params"].items()]
    sections = [
        f"π― Training Recommendations for {task.title()} ({dataset_size} dataset)",
        "**Recommended Models:**",
        bullets(entry["models"]),
        "**Recommended Parameters:**",
        bullets(param_items),
        "**Backend Suggestions:**",
        bullets(entry["backends"]),
        "**Best Practices:**",
        bullets(entry["notes"]),
    ]
    return "\n".join(sections)
def get_system_status(random_string: str = "") -> str:
    """
    Get AutoTrain system status and capabilities.

    Args:
        random_string: Unused by this function.

    Returns:
        System status, available tasks, backends, and statistics
    """
    try:
        runs = load_runs()

        def count_with(state):
            """Count stored runs whose status equals ``state``."""
            return sum(1 for entry in runs if entry.get("status") == state)

        # Both W&B indicators depend only on whether an API key is set.
        if os.environ.get("WANDB_API_KEY"):
            wandb_api_status = "β Configured"
            wandb_metrics_status = "β Enabled"
        else:
            wandb_api_status = "β Missing"
            wandb_metrics_status = "β System metrics only"

        report = [
            "## βοΈ System Status",
            "### π Run Statistics",
            "| Metric | Count |",
            "|--------|-------|",
            "| **Server Status** | β Running |",
            f"| **Total Runs** | {len(runs)} |",
            f"| **Active Runs** | {count_with('running')} |",
            f"| **Completed Runs** | {count_with('completed')} |",
            f"| **Failed Runs** | {count_with('failed')} |",
            "### π‘ Access Points",
            "| Service | URL |",
            "|---------|-----|",
            "| **Gradio UI** | http://SPACE_URL |",
            "| **MCP Server** | http://SPACE_URL/gradio_api/mcp/sse |",
            "| **MCP Schema** | http://SPACE_URL/gradio_api/mcp/schema |",
            "### π οΈ W&B Integration",
            "| Component | Status |",
            "|-----------|--------|",
            f"| **Project** | {WANDB_PROJECT} |",
            f"| **API Key** | {wandb_api_status} |",
            f"| **Training Metrics** | {wandb_metrics_status} |",
            "π‘ **Note:** Set WANDB_API_KEY for complete training metrics logging",
        ]
        return "\n".join(report)
    except Exception as e:
        return f"β Error getting system status: {str(e)}"
def refresh_data(random_string: str = "") -> str:
    """Return a fixed confirmation message; ``random_string`` is ignored."""
    confirmation = "Data refreshed successfully"
    return confirmation
def load_initial_data(random_string: str = "") -> str:
    """Return a fixed confirmation message; ``random_string`` is ignored."""
    confirmation = "Initial data loaded successfully"
    return confirmation
# Web UI Functions
def fetch_runs_for_ui():
    """Fetch runs for the web interface table.

    Returns:
        A ``pandas.DataFrame`` with one row per stored run, newest first,
        and the columns Status, W&B Link, Project, Task, Model, Created and
        Run ID. The frame is empty (same columns) when there are no runs.
        If anything fails, a one-row frame with an ``Error`` column is
        returned instead.
    """
    columns = ["Status", "W&B Link", "Project", "Task", "Model", "Created", "Run ID"]
    try:
        rows = []
        for run in reversed(load_runs()):  # Newest first
            wandb_url = run.get("wandb_url")
            rows.append(
                {
                    "Status": f"{get_status_emoji(run['status'])} {run['status'].title()}",
                    # Markdown link; the table component renders it.
                    "W&B Link": f"[π W&B Run]({wandb_url})" if wandb_url else "",
                    "Project": run["project_name"],
                    "Task": run["task"].replace("-", " ").title(),
                    "Model": run["base_model"],
                    # "YYYY-MM-DDTHH:MM..." -> "YYYY-MM-DD HH:MM"
                    "Created": run["created_at"][:16].replace("T", " "),
                    "Run ID": run["run_id"][:8] + "...",
                }
            )
        # pandas.DataFrame has no `datatype` argument (that belongs to
        # gr.Dataframe); passing it raised TypeError for every non-empty list.
        return pd.DataFrame(rows, columns=columns)
    except Exception as e:
        return pd.DataFrame({"Error": [f"Failed to fetch runs: {str(e)}"]})
def submit_training_job_ui(
    task,
    project_name,
    base_model,
    dataset_path,
    epochs,
    batch_size,
    learning_rate,
    backend,
    push_to_hub,
    hub_repo_id,
):
    """Submit a training job from the web UI.

    Returns:
        A ``(message, runs_table)`` pair: the submission result text and the
        refreshed runs table.
    """
    mandatory = (task, project_name, base_model, dataset_path)
    if any(not value for value in mandatory):
        return "β Please fill in all required fields", fetch_runs_for_ui()

    # The tool function takes its numeric and boolean options as strings.
    job_options = {
        "task": task,
        "project_name": project_name,
        "base_model": base_model,
        "dataset_path": dataset_path,
        "epochs": str(epochs),
        "batch_size": str(batch_size),
        "learning_rate": str(learning_rate),
        "backend": backend,
        "push_to_hub": str(push_to_hub).lower(),
        "hub_repo_id": hub_repo_id,
    }
    message = start_training_job(**job_options)
    return message, fetch_runs_for_ui()
# Create Gradio Interface
# NOTE(review): block nesting below is reconstructed from syntax because this
# copy of the file lost its indentation; blank lines and JSON indentation
# inside the Markdown literal also appear to have been lost - confirm against
# the original.
with gr.Blocks(
    title="AutoTrain Gradio MCP Server",
    theme=gr.themes.Soft(),
    css="""
.gradio-container {
max-width: 1200px !important;
}
""",
) as app:
    gr.Markdown("""
# π AutoTrain MCP Server
Get your AI models to train your AI models!
This space is an MCP server that you can use in Claude Desktop, Cursor, VSCode, etc to train your AI models.
:warning: To train models you will need to duplicate this space!
**MCP Server**: AI assistants can use tools at http://SPACE_URL/gradio_api/mcp/sse
Connect to it like this:
```javascript
{
"mcpServers": {
"autotrain": {
"url": "http://SPACE_URL/gradio_api/mcp/sse",
"headers": {"Authorization": "Bearer <YOUR-HUGGING-FACE-TOKEN>"}
}
}
}
```
Or like this for Claude Desktop:
```javascript
{
"mcpServers": {
"autotrain": {
"command": "npx",
"args": [
"mcp-remote",
"http://SPACE_URL/gradio_api/mcp/sse",
"--header",
"Authorization: Bearer <YOUR-HUGGING-FACE-TOKEN>"
]
}
}
}
```
""")

    with gr.Tabs():
        # Dashboard Tab
        with gr.Tab("π Training Runs"):
            with gr.Row():
                runs_table = gr.Dataframe(
                    value=fetch_runs_for_ui(), interactive=False, datatype="markdown"
                )
            with gr.Row():
                refresh_btn = gr.Button("π Refresh", variant="secondary")

        with gr.Tab("π§ System Status"):
            stats = gr.Markdown(value=get_system_status())

        # MCP Tools Tab
        with gr.Tab("π§ MCP Tools"):
            gr.Markdown("## MCP Tool Testing Interface")
            gr.Markdown("These tools are exposed via MCP for Claude Desktop")

            gr.Interface(
                fn=get_system_status,
                inputs=[],
                outputs=gr.Textbox(label="System Status"),
                title="get_system_status",
                description="Get AutoTrain system status and capabilities",
            )
            gr.Interface(
                fn=get_training_runs,
                inputs=[
                    gr.Textbox(label="limit", value="20"),
                    gr.Textbox(label="status", value=""),
                ],
                outputs=gr.Textbox(label="Training Runs"),
                title="get_training_runs",
                description="Get list of training runs with status",
            )
            gr.Interface(
                fn=start_training_job,
                inputs=[
                    gr.Textbox(label="task", value="text-classification"),
                    gr.Textbox(label="project_name", value="test-project"),
                    gr.Textbox(label="base_model", value="distilbert-base-uncased"),
                    gr.Textbox(label="dataset_path", value="imdb"),
                    gr.Textbox(label="epochs", value="1"),
                    gr.Textbox(label="batch_size", value="8"),
                    gr.Textbox(label="learning_rate", value="2e-5"),
                    gr.Textbox(label="backend", value="local"),
                    gr.Textbox(label="push_to_hub", value="false"),
                    gr.Textbox(label="hub_repo_id", placeholder="your-repo-id"),
                ],
                outputs=gr.Textbox(label="Training Job Result"),
                title="start_training_job",
                description="Start a new AutoTrain training job",
            )
            gr.Interface(
                fn=get_run_details,
                inputs=gr.Textbox(
                    label="run_id", placeholder="Enter run ID or first 8 chars"
                ),
                outputs=gr.Textbox(label="Run Details"),
                title="get_run_details",
                description="Get detailed information about a training run",
            )
            gr.Interface(
                fn=get_task_recommendations,
                inputs=[
                    gr.Textbox(label="task", value="text-classification"),
                    gr.Textbox(label="dataset_size", value="medium"),
                ],
                outputs=gr.Textbox(label="Recommendations"),
                title="get_task_recommendations",
                description="Get training recommendations for a task",
            )

    # Event handlers with proper function names (not lambda)
    def refresh_ui_data():
        """Return fresh values for the runs table and the status panel."""
        return fetch_runs_for_ui(), get_system_status()

    def load_initial_ui_data():
        """Return the initial values for the runs table and the status panel."""
        return fetch_runs_for_ui(), get_system_status()

    refresh_btn.click(
        fn=refresh_ui_data,
        outputs=[runs_table, stats],
    )

    # Load initial data
    app.load(
        fn=load_initial_ui_data,
        outputs=[runs_table, stats],
    )
| # Helper to find an available port | |
| def _find_available_port(start_port: int = 7860, max_tries: int = 20) -> int: | |
| """Return the first available port starting from `start_port`.""" | |
| port = start_port | |
| for _ in range(max_tries): | |
| with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: | |
| try: | |
| s.bind(("0.0.0.0", port)) | |
| return port # Port is free | |
| except OSError: | |
| port += 1 # Try next port | |
| # If no port found, let OS pick one | |
| return 0 | |
if __name__ == "__main__":
    # A malformed GRADIO_SERVER_PORT should not stop the app from starting.
    try:
        requested_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    except ValueError:
        requested_port = 7860

    try:
        chosen_port = _find_available_port(requested_port)
    except Exception:
        # Fall back to "no preference" if probing goes wrong
        chosen_port = 0

    app.launch(
        server_name="0.0.0.0",
        # 0 means no free port was found; None lets Gradio search for one
        # itself instead of being asked to serve on port 0.
        server_port=chosen_port or None,
        mcp_server=True,  # Enable MCP server functionality
    )