"""Flask backend: serves the built frontend and a small JSON API.

The API starts one backend session per user (``/api/start``), lists models
for a provider (``/api/models/<provider>``) and exposes two owner-only
diagnostic endpoints (``/api/processes`` and ``/api/sessions``).
"""

import hmac
import json
import logging
import os
import pathlib
import threading
import time
from datetime import datetime, timedelta
from urllib.parse import quote

import psutil
import requests
from flask import Flask, jsonify, request, send_from_directory

from backend.are import get_are_url, start_are_process_and_session_lite
from backend.cleanup import cleanup
from backend.globals import STORAGE_PATH
from backend.iframe import validate_mcp_file
from backend.session import UserSession

# Ensure storage directory exists
os.makedirs(STORAGE_PATH, exist_ok=True)

# Maps a browser session-cookie value to the "__sign" query value seen on "/".
# NOTE(review): entries are never evicted, so this grows for the process lifetime.
AUTH_SESSION_MANAGEMENT = {}
# Maps a username to the UserSession currently registered for that user.
SESSION_MANAGEMENT = {}

# Cookie names probed (in this order) to identify the browser session.
SESSION_COOKIE_NAMES = (
    "session",
    "spaces-jwt",
    "sessionid",
    "JSESSIONID",
    "connect.sid",
)

# Sessions older than this are torn down by cleanup_old_sessions().
SESSION_MAX_AGE = timedelta(hours=2)

# Accepted textual forms of UserSession.start_time. The second form covers
# timestamps rendered without a fractional part.
_START_TIME_FORMATS = ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S")

HF_MODELS_API = "https://huggingface.co/api/models"
HF_PIPELINE_TAGS = ("image-text-to-text", "text-generation")

# Model IDs from https://llama.developer.meta.com/docs/models/
LLAMA_API_MODELS = [
    "Llama-4-Maverick-17B-128E-Instruct-FP8",
    "Llama-4-Scout-17B-16E-Instruct-FP8",
    "Llama-3.3-70B-Instruct",
    "Llama-3.3-8B-Instruct",
    "Cerebras-Llama-4-Maverick-17B-128E-Instruct",
    "Cerebras-Llama-4-Scout-17B-16E-Instruct",
    "Groq-Llama-4-Maverick-17B-128E-Instruct",
]

# psutil attributes collected for every process in /api/processes.
_PROCESS_ATTRS = [
    "pid",
    "ppid",
    "name",
    "username",
    "status",
    "create_time",
    "cpu_percent",
    "memory_percent",
    "memory_info",
    "cmdline",
]

# Serve the static frontend and expose a minimal API
app = Flask(
    __name__,
    static_folder=os.path.join(os.path.dirname(__file__), "frontend", "build"),
    static_url_path="",
)

logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)


def _parse_start_time(value) -> datetime:
    """Parse a session start time string into a naive ``datetime``.

    Raises:
        ValueError: if *value* matches none of ``_START_TIME_FORMATS``.
        TypeError: if *value* is not a string.
    """
    for fmt in _START_TIME_FORMATS:
        try:
            return datetime.strptime(value, fmt)
        except ValueError:
            continue
    raise ValueError(f"unrecognised start time {value!r}")


def cleanup_old_sessions() -> None:
    """Clean up sessions that are older than ``SESSION_MAX_AGE`` (2 hours).

    Sessions whose start time cannot be read or parsed are treated as expired.
    Never raises: every failure is logged instead.
    """
    try:
        current_time = datetime.now()
        expired = []

        # Iterate over a snapshot: this function runs in a background thread
        # while request handlers add and remove entries.
        for username, session in list(SESSION_MANAGEMENT.items()):
            sid = getattr(session, "sid", "unknown")
            try:
                session_age = current_time - _parse_start_time(session.start_time)
            except (ValueError, TypeError, AttributeError) as e:
                logger.warning(
                    f"Could not parse start time for session "
                    f"{sid} (user: {username}): {e}"
                )
                # If we can't parse the time, assume it's old and clean it up
                expired.append((username, session))
                continue

            if session_age > SESSION_MAX_AGE:
                logger.info(
                    f"Session {sid} for user {username} "
                    f"is {session_age} old, marking for cleanup"
                )
                expired.append((username, session))

        for username, session in expired:
            # Skip entries that were replaced or removed since the snapshot, so
            # a freshly started session is never torn down by mistake.
            if SESSION_MANAGEMENT.get(username) is not session:
                continue

            sid = getattr(session, "sid", "unknown")
            logger.info(f"Cleaning up old session {sid} for user {username}")
            try:
                cleanup(session)
                logger.info(
                    f"Successfully cleaned up old session {sid} for user {username}"
                )
            except Exception as e:
                logger.error(
                    f"Error cleaning up old session {sid} for user {username}: {e}"
                )
            finally:
                # Remove from SESSION_MANAGEMENT even if cleanup failed
                # to prevent accumulation of broken sessions
                if SESSION_MANAGEMENT.get(username) is session:
                    SESSION_MANAGEMENT.pop(username, None)

        if expired:
            logger.info(f"Cleaned up {len(expired)} old sessions")

    except Exception as e:
        logger.error(f"Error during old session cleanup: {e}")


def cleanup_session_async(user_session: UserSession) -> None:
    """Run cleanup in the background to avoid blocking the main thread.

    Does nothing when *user_session* is ``None``. After cleaning the given
    session, the worker also sweeps expired sessions via
    ``cleanup_old_sessions()``.
    """
    if user_session is None:
        return

    def run_cleanup():
        try:
            session_id = user_session.sid
            logger.info(f"Starting background cleanup for session {session_id}")
            cleanup(user_session)
            logger.info(f"Background cleanup completed for session {session_id}")

            # Also clean up any other old sessions while we're at it
            logger.info("Checking for old sessions to clean up...")
            cleanup_old_sessions()
        except Exception as e:
            session_id = getattr(user_session, "sid", "unknown")
            logger.error(
                f"Error during background cleanup for session {session_id}: {e}"
            )

    # Start cleanup in a separate daemon thread
    cleanup_thread = threading.Thread(target=run_cleanup, daemon=True)
    cleanup_thread.start()


def get_session_from_cookie(cookie):
    """Return the value of the first known session cookie, or ``None``.

    *cookie* is a mapping of cookie names to values (e.g. ``request.cookies``).
    Names are probed in ``SESSION_COOKIE_NAMES`` order.
    """
    if not cookie:
        return None
    for session_name in SESSION_COOKIE_NAMES:
        if session_name in cookie:
            return cookie[session_name]
    return None


def _owner_key_error():
    """Check the ``key`` query parameter against ``OWNER_SECRET``.

    Returns ``None`` when the request is authorised, otherwise the
    ``(response, status)`` tuple the view should return.
    """
    key = request.args.get("key")
    if not key:
        return jsonify({"error": "Unauthorized access"}), 403

    owner_secret = os.environ.get("OWNER_SECRET")
    if not owner_secret:
        return jsonify({"error": "Server configuration error"}), 500

    # Constant-time comparison; compare bytes so non-ASCII keys cannot raise.
    if not hmac.compare_digest(key.encode("utf-8"), owner_secret.encode("utf-8")):
        return jsonify({"error": "Unauthorized access"}), 403
    return None


@app.get("/")
def index():
    """Serve the main HTML document.

    When both a ``__sign`` query parameter and a session cookie are present,
    the sign value is remembered for that cookie session.
    """
    sign_in_info = request.args.get("__sign", default=None, type=str)
    cookie_session = get_session_from_cookie(request.cookies)
    if sign_in_info is not None and cookie_session is not None:
        AUTH_SESSION_MANAGEMENT[cookie_session] = sign_in_info
        # Both values are credentials: never write them to the log.
        logger.info("Stored sign-in info for the current cookie session")
    return send_from_directory(app.static_folder, "index.html")


@app.get("/demo-mcp.json")
def demo_mcp():
    """Serve the demo MCP file.

    Looks in the built frontend first (production), then in the frontend
    ``public`` directory (development). Returns a JSON 404 if neither has it.
    """
    # A relative folder is resolved against app.root_path, mirroring how
    # send_from_directory itself resolves relative directories.
    public_folder = os.path.join(
        app.root_path, os.path.dirname(__file__), "frontend", "public"
    )
    # send_from_directory signals a missing file with an HTTP NotFound rather
    # than FileNotFoundError, so probe the filesystem explicitly.
    for folder in (app.static_folder, public_folder):
        if folder and os.path.isfile(os.path.join(folder, "demo-mcp.json")):
            return send_from_directory(folder, "demo-mcp.json")

    logger.error("demo-mcp.json not found in either build or public directory")
    return jsonify({"error": "demo-mcp.json not found"}), 404


# The URL variable is required: the view takes ``provider`` as an argument.
@app.get("/api/models/<provider>")
def get_models_for_provider(provider):
    """Fetch available models for a given provider from Hugging Face API.

    ``llama-api`` is answered from a static list. For any other provider the
    Hugging Face models API is queried once per pipeline tag and the model IDs
    are merged, de-duplicated and sorted.

    Returns ``{"models": [...]}`` with 200, or an error object with 500.
    """
    if provider == "llama-api":
        return jsonify({"models": LLAMA_API_MODELS}), 200

    try:
        # safe="" so that "/" and other reserved characters are escaped too
        encoded_provider = quote(provider, safe="")

        ids_by_tag = {}
        for tag in HF_PIPELINE_TAGS:
            url = (
                f"{HF_MODELS_API}?pipeline_tag={tag}"
                f"&inference_provider={encoded_provider}"
            )
            response = requests.get(url, timeout=10)
            if response.status_code != 200:
                logger.error(
                    f"Failed to fetch {tag} models for provider {provider}: "
                    f"{response.status_code}"
                )
                return (
                    jsonify(
                        {
                            "error": f"Failed to fetch {tag} models",
                            "status": response.status_code,
                        }
                    ),
                    500,
                )
            ids_by_tag[tag] = [
                model["id"] for model in response.json() if model.get("id")
            ]

        # Merge, deduplicate and sort the model IDs alphabetically
        model_ids = sorted(set().union(*ids_by_tag.values()))

        logger.info(
            f"Fetched {len(ids_by_tag['image-text-to-text'])} image-text-to-text "
            f"models and {len(ids_by_tag['text-generation'])} text-generation "
            f"models for provider {provider} "
            f"(total: {len(model_ids)} unique models)"
        )
        return jsonify({"models": model_ids}), 200

    except requests.RequestException as e:
        logger.error(f"Network error when fetching models for provider {provider}: {e}")
        return jsonify({"error": "Network error", "detail": str(e)}), 500
    except Exception as e:
        logger.error(
            f"Unexpected error when fetching models for provider {provider}: {e}"
        )
        return jsonify({"error": "Internal error", "detail": str(e)}), 500


def _python_process_info(pinfo):
    """Build the API record for one process, or ``None`` if it is not Python.

    *pinfo* is the dict returned by ``psutil.Process.as_dict``; any field may
    be ``None`` when psutil was denied access to it.
    """
    name = pinfo.get("name") or ""
    cmdline = " ".join(pinfo.get("cmdline") or [])

    is_python = (
        "python" in name.lower()
        or "python" in cmdline.lower()
        or cmdline.endswith(".py")
    )
    if not is_python:
        return None

    # time.localtime(None) would silently mean "now", so guard explicitly.
    created = pinfo.get("create_time")
    if created is None:
        create_time = "unknown"
    else:
        create_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created))

    memory_info = pinfo.get("memory_info")
    rss_mb = (memory_info.rss / 1024 / 1024) if memory_info else 0
    vms_mb = (memory_info.vms / 1024 / 1024) if memory_info else 0

    return {
        "pid": pinfo["pid"],
        "ppid": pinfo["ppid"],
        "name": pinfo["name"],
        "username": pinfo.get("username") or "unknown",
        "status": pinfo["status"],
        "cpu_percent": round(pinfo.get("cpu_percent") or 0, 2),
        "memory_percent": round(pinfo.get("memory_percent") or 0, 2),
        "memory_rss_mb": round(rss_mb, 2),
        "memory_vms_mb": round(vms_mb, 2),
        "create_time": create_time,
        # Cap very long command lines at 200 characters
        "cmdline": cmdline[:200] + "..." if len(cmdline) > 200 else cmdline,
    }


@app.get("/api/processes")
def list_python_processes():
    """List running Python processes (owner-only, requires ``?key=``).

    Returns ``{"processes": [...], "count": n}`` sorted by PID.
    """
    auth_error = _owner_key_error()
    if auth_error is not None:
        return auth_error

    try:
        python_processes = []

        # Iterate through all running processes using psutil
        for proc in psutil.process_iter():
            try:
                process_info = _python_process_info(proc.as_dict(attrs=_PROCESS_ATTRS))
            except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
                # Process may have terminated or we don't have permission
                continue
            except Exception as e:
                logger.warning(f"Error processing process {proc.pid}: {e}")
                continue
            if process_info is not None:
                python_processes.append(process_info)

        # Sort by PID for consistent ordering
        python_processes.sort(key=lambda info: info["pid"])

        return (
            jsonify({"processes": python_processes, "count": len(python_processes)}),
            200,
        )

    except Exception as e:
        logger.error(f"Unexpected error listing processes: {e}")
        return jsonify({"error": "Internal server error"}), 500


@app.get("/api/sessions")
def list_active_sessions():
    """List registered user sessions (owner-only, requires ``?key=``).

    Returns ``{"sessions": [...], "count": n}``, most recently started first.
    A session whose details cannot be collected is still listed, with
    ``process_status`` set to ``"error"``.
    """
    auth_error = _owner_key_error()
    if auth_error is not None:
        return auth_error

    try:
        active_sessions = []
        current_time = datetime.now()

        # Snapshot: background cleanup threads may mutate the registry.
        for username, session in list(SESSION_MANAGEMENT.items()):
            # Start from the "could not inspect" record and fill in the live
            # values only if everything below succeeds.
            session_info = {
                "username": username,
                "session_id": getattr(session, "sid", None),
                "pid": getattr(session, "pid", None),
                "port": getattr(session, "port", None),
                "model": getattr(session, "model", None),
                "provider": getattr(session, "provider", None),
                "start_time": getattr(session, "start_time", None),
                "duration_hours": "unknown",
                "log_path": getattr(session, "log_path", None),
                "process_status": "error",
                "cpu_percent": 0,
                "memory_percent": 0,
            }
            try:
                # Calculate session duration
                started = _parse_start_time(session.start_time)
                duration_hours = (current_time - started).total_seconds() / 3600

                # Check if process is still running
                process_status = "unknown"
                cpu_percent = 0
                memory_percent = 0
                try:
                    proc = psutil.Process(session.pid)
                    process_status = proc.status()
                    cpu_percent = proc.cpu_percent()
                    memory_percent = proc.memory_percent()
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    process_status = "not_found"

                session_info.update(
                    duration_hours=round(duration_hours, 2),
                    process_status=process_status,
                    cpu_percent=round(cpu_percent, 2),
                    memory_percent=round(memory_percent, 2),
                )
            except Exception as e:
                # Still include basic info even if we can't get process details
                logger.warning(f"Error processing session for user {username}: {e}")

            active_sessions.append(session_info)

        def start_time_key(info):
            value = info["start_time"]
            # Missing/unknown start times sort last (most recent first order).
            return "" if value in (None, "unknown") else str(value)

        # Sort by start time (most recent first)
        active_sessions.sort(key=start_time_key, reverse=True)

        return (
            jsonify({"sessions": active_sessions, "count": len(active_sessions)}),
            200,
        )

    except Exception as e:
        logger.error(f"Unexpected error listing sessions: {e}")
        return jsonify({"error": "Internal server error"}), 500


def _is_safe_username(username) -> bool:
    """Return whether *username* may be used as a storage directory name.

    ``None`` is accepted for backward compatibility. Strings must not contain
    path separators or NUL, and must not be ``"."`` or ``".."``.
    """
    if username is None:
        return True
    if not isinstance(username, str):
        return False
    if username in (".", ".."):
        return False
    return not any(ch in username for ch in ("/", "\\", "\0"))


@app.post("/api/start")
def start_demo():
    """Start a backend session for the posted form payload.

    Expects a JSON object with optional ``user``, ``model``, ``provider`` and
    ``mcp`` fields. An ``mcp`` string is validated and stored under the user's
    storage directory. Any session already registered for the user is cleaned
    up in the background before the new one is started.

    Returns ``{"ok": True, "received": True, "iframe_url", "health_url"}`` with
    200, a 400 for an invalid body, user or MCP file, or a 500 if the MCP file
    cannot be processed or the session cannot be started.
    """
    try:
        data = request.get_json(force=True, silent=False)
    except Exception as exc:
        # `data` is not bound here, so only the exception can be reported.
        logger.exception("Invalid JSON body")
        return jsonify({"ok": False, "error": "invalid_json", "detail": str(exc)}), 400

    if not isinstance(data, dict):
        logger.warning(
            "Invalid JSON body: expected an object, got %s", type(data).__name__
        )
        return jsonify({"ok": False, "error": "invalid_json"}), 400

    username = data.get("user")
    # `user` comes from the client and is used to build a filesystem path.
    if not _is_safe_username(username):
        logger.warning("Rejected /api/start request with invalid user %r", username)
        return jsonify({"ok": False, "error": "invalid_user"}), 400

    cookie_session = get_session_from_cookie(request.cookies)
    # Fall back to the raw cookie value when no sign-in info was recorded.
    signin_token = AUTH_SESSION_MANAGEMENT.get(cookie_session, cookie_session)

    # Raw payload logging
    # NOTE(review): this includes the full `mcp` body, which may be large or
    # contain secrets - consider dropping it from the log.
    logger.info(
        "/api/start payload:\n%s", json.dumps(data, indent=2, ensure_ascii=False)
    )

    # Request metadata and a concise payload summary
    client_ip = (
        (request.headers.get("X-Forwarded-For") or request.remote_addr or "-")
        .split(",")[0]
        .strip()
    )
    user_agent = request.headers.get("User-Agent", "-")
    referer = request.headers.get("Referer", "-")
    content_type = request.content_type
    content_length = request.content_length

    auth_header = request.headers.get("Authorization")
    user_token = None
    if auth_header and auth_header.lower().startswith("bearer "):
        user_token = auth_header.split(" ", 1)[1].strip()

    # Credentials are reported by presence only, never by value.
    logger.info(
        {
            "user_agent": user_agent,
            "referer": referer,
            "content_type": content_type,
            "content_length": content_length,
            "auth_header": "<redacted>" if auth_header else None,
            "user_token": "<redacted>" if user_token else None,
        }
    )

    # MCP validation
    mcp_text = data.get("mcp")
    mcp_json_path = None
    if isinstance(mcp_text, str):
        try:
            mcp_data = validate_mcp_file(mcp_text, user_token)
            mcp_json_path = f"{STORAGE_PATH}/{username}/mcp.json"
            os.makedirs(f"{STORAGE_PATH}/{username}", exist_ok=True)
            with open(pathlib.Path(mcp_json_path), "w", encoding="utf-8") as file:
                json.dump(mcp_data, file, indent=4)
        except ValueError as e:
            logger.error(f"MCP file validation failed: {e}")
            return (
                jsonify({"ok": False, "error": "invalid_mcp_file", "detail": str(e)}),
                400,
            )
        except Exception as e:
            logger.error(f"Could not process MCP file: {e}")
            return (
                jsonify(
                    {
                        "ok": False,
                        "error": "mcp_processing_failed",
                        "detail": f"Failed to process MCP file: {str(e)}",
                    }
                ),
                500,
            )

    # Killing previous session
    logger.info(f"Current SESSION_MANAGEMENT keys: {list(SESSION_MANAGEMENT.keys())}")
    logger.info(f"Looking for username: {username}")

    previous_session = SESSION_MANAGEMENT.pop(username, None)
    if previous_session is not None:
        logger.info(f"Killing existing session for {username}: {previous_session}")
        cleanup_session_async(previous_session)
        logger.info(
            f"Started background cleanup for previous session of user {username}"
        )
    else:
        logger.info(f"No previous processes to kill for {username}")

    try:
        user_session: UserSession = start_are_process_and_session_lite(
            model=data.get("model", ""),
            provider=data.get("provider", ""),
            username=username,
            bearer_token=signin_token,
            user_token=user_token,
            app_path=mcp_json_path,
        )
    except Exception as e:
        # Answer with JSON like every other failure path of this endpoint.
        logger.exception(f"Could not start session for user {username}")
        return (
            jsonify({"ok": False, "error": "session_start_failed", "detail": str(e)}),
            500,
        )

    SESSION_MANAGEMENT[username] = user_session

    logger.info(f"User SESSION {user_session}")
    logger.info(
        f"Started session {user_session.sid} on port {user_session.port} "
        f"for user {user_session.user}"
    )

    iframe_url: str = get_are_url(session=user_session, server="are_simulation")
    health_url: str = get_are_url(session=user_session, server="health")

    summary = {
        "client": {
            "ip": client_ip,
            "user_agent": user_agent,
            "referer": referer,
            "content_type": content_type,
            "content_length": content_length,
        },
        "received_fields": {
            "model": data.get("model"),
            "provider": data.get("provider"),
            "user": username,
        },
        "auth": {
            "signin_token": "<redacted>" if signin_token else None,
        },
    }

    logger.info("/api/start summary: %s", json.dumps(summary, ensure_ascii=False))

    return (
        jsonify(
            {
                "ok": True,
                "received": True,
                "iframe_url": iframe_url,
                "health_url": health_url,
            }
        ),
        200,
    )


def run():
    """Run the development/Space server on ``$PORT`` (default 7860)."""
    port = int(os.environ.get("PORT", "7860"))
    app.run(host="0.0.0.0", port=port)


if __name__ == "__main__":
    run()