import os import json from datetime import datetime import asyncio import aiohttp from typing import Dict, List, Optional from fastapi import FastAPI, HTTPException from pydantic import BaseModel, HttpUrl import uvicorn from git_clone import clone_repository # ===== CONFIG ===== class Settings: # Server URLs and Ports CONTROLLER_HOST = "0.0.0.0" # Listen on all interfaces CONTROLLER_PORT = 8000 # This should be the actual IP or hostname where controller is accessible CONTROLLER_BASE_URL = os.getenv("CONTROLLER_BASE_URL", "http://192.168.1.100:8000") # List of tensor server URLs - should be actual IP addresses or hostnames TENSOR_SERVER_URLS = [ url for url in os.getenv("TENSOR_SERVER_URLS", "").split(",") if url ] or [ "https://fred808-ilob.hf.space", "https://fred808-tserv.hf.space", "https://fred808-tserve2.hf.space", ] AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002") # Model settings MODEL_REPO = "https://huggingface.co/inference-net/Schematron-8B" # Server settings TENSOR_SERVER_TIMEOUT = 30 # seconds MAX_ERROR_THRESHOLD = 5 # maximum number of errors SERVER_TIMEOUT = 60 # seconds before marking as error MONITORING_INTERVAL = 15 # seconds between health checks # Dynamic distribution settings @classmethod def get_optimal_chunk_size(cls, total_params: int, num_servers: int) -> int: """Calculate optimal chunk size based on number of servers""" # Aim for 2-3 chunks per server for better parallelism target_chunks = num_servers * 2 return max(1, total_params // target_chunks) @classmethod def get_min_servers_required(cls) -> int: """Dynamically calculate minimum servers needed based on registered servers""" return max(2, len(cls.TENSOR_SERVER_URLS) // 3) # At least 1/3 of registered servers @classmethod def get_min_replica_count(cls, num_servers: int) -> int: """Calculate minimum replicas based on server count""" return max(2, num_servers // 4) # At least 25% of servers should have each chunk # Tokenizer settings MAX_SEQUENCE_LENGTH = 2048 VOCAB_SIZE = 50257 @classmethod def from_env(cls): """Load settings from environment variables""" cls.CONTROLLER_HOST = os.getenv("CONTROLLER_HOST", cls.CONTROLLER_HOST) cls.CONTROLLER_PORT = int(os.getenv("CONTROLLER_PORT", cls.CONTROLLER_PORT)) cls.CONTROLLER_BASE_URL = os.getenv("CONTROLLER_BASE_URL", cls.CONTROLLER_BASE_URL) # Load tensor server URLs from environment tensor_urls = os.getenv("TENSOR_SERVER_URLS") if tensor_urls: cls.TENSOR_SERVER_URLS = tensor_urls.split(",") cls.AGGREGATOR_HOST = os.getenv("AGGREGATOR_HOST", cls.AGGREGATOR_HOST) cls.AGGREGATOR_PORT = int(os.getenv("AGGREGATOR_PORT", cls.AGGREGATOR_PORT)) cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", f"http://{cls.AGGREGATOR_HOST}:{cls.AGGREGATOR_PORT}") return cls # ===== State Models ===== class ServerMetrics(BaseModel): """Metrics for tensor server performance and load""" cpu_usage: float = 0.0 memory_usage: float = 0.0 gpu_usage: Optional[float] = None active_requests: int = 0 total_requests: int = 0 average_response_time: float = 0.0 last_error: Optional[str] = None error_count: int = 0 class TensorServer(BaseModel): """Represents a registered tensor server""" url: HttpUrl status: str = "initializing" # initializing, ready, busy, error, degraded last_heartbeat: datetime = datetime.now() model_chunks: List[int] = [] # List of chunk IDs assigned to this server metrics: ServerMetrics = ServerMetrics() version: str = "1.0.0" capabilities: Dict[str, bool] = { "gpu_available": False, "quantization_support": False, "tensor_parallelism": False } class ModelChunk(BaseModel): """Represents a chunk of the model to be sent to a tensor server""" chunk_id: int files: List[str] # files included in this chunk config: Dict # configuration for this chunk size_bytes: int = 0 server_assignments: List[str] = [] # URLs of servers holding this chunk status: str = "unassigned" # unassigned, assigned, loaded, error metrics: Dict[str, float] = { "load_time": 0.0, "memory_usage": 0.0, "average_inference_time": 0.0 } # ===== FastAPI App ===== app = FastAPI( title="Florence-2 Model Controller", description="Controls model distribution across tensor servers", version="1.0.0" ) # ===== Global State ===== class ControllerState: def __init__(self): self.model_files: Dict[str, str] = {} # Mapping of filename to file path self.model_config: Dict = {} # Model configuration self.tensor_servers: Dict[str, TensorServer] = {} self.model_chunks: Dict[int, ModelChunk] = {} self.is_model_loaded = False self.model_path: str = "" # Base path where model files are stored self.chunks_dir: str = "" # Directory containing chunk files self.operation_results: Dict[str, Dict] = {} # Track operation results from tensor servers self.pending_operations: Dict[str, asyncio.Task] = {} # Track ongoing operations state = ControllerState() # ===== Helper Functions ===== async def split_model_weights(): """Split model files into chunks based on available servers without loading into memory""" try: import os import math import shutil from pathlib import Path # Find model weight files (safetensors or pytorch) weight_files = [f for f in state.model_files.values() if f.endswith(('.safetensors', '.bin'))] if not weight_files: raise Exception("No model weight files found") # The current splitting logic only supports splitting a single file. # If there are multiple files, we assume they are sharded and need a different approach. # For now, we will select the largest file to split, or the first one if all are small. model_file = max(weight_files, key=os.path.getsize) if len(weight_files) > 1 else weight_files[0] if len(weight_files) > 1: print(f"[WARN] Found multiple weight files. Selecting the largest one for splitting: {model_file}") else: print(f"[INFO] Found model weight file: {model_file}") # Get file size and calculate chunks try: with open(model_file, 'rb') as f: # Get actual file size by seeking to end f.seek(0, 2) # Seek to end file_size = f.tell() # Get position (total size) f.seek(0) # Reset to beginning # Read first few bytes to verify file isn't corrupted header = f.read(8) if len(header) == 0: raise ValueError(f"File is empty: {model_file}") except Exception as e: raise Exception(f"Failed to read model file {model_file}: {str(e)}") # Verify file size is reasonable if file_size < 1024: # Less than 1KB raise ValueError(f"Model file suspiciously small ({file_size} bytes). Possible corruption or incomplete download.") num_servers = len(state.tensor_servers) or len(Settings.TENSOR_SERVER_URLS) num_chunks = num_servers # One chunk per server initially chunk_size = math.ceil(file_size / num_chunks) # Format sizes for display def format_size(size_bytes): if size_bytes >= 1024*1024*1024: # GB return f"{size_bytes / (1024*1024*1024):.2f} GB ({size_bytes:,} bytes)" elif size_bytes >= 1024*1024: # MB return f"{size_bytes / (1024*1024):.2f} MB ({size_bytes:,} bytes)" elif size_bytes >= 1024: # KB return f"{size_bytes / 1024:.2f} KB ({size_bytes:,} bytes)" else: return f"{size_bytes:,} bytes" print(f"[INFO] Model file size: {format_size(file_size)}") print(f"[INFO] Creating {num_chunks} chunks of approximately {format_size(chunk_size)} each") # Use the chunks directory from state os.makedirs(state.chunks_dir, exist_ok=True) # Split the file into chunks with open(model_file, 'rb') as f: chunk_sizes = [] # Track actual chunk sizes for chunk_id in range(num_chunks): chunk_path = os.path.join(state.chunks_dir, f"chunk_{chunk_id}.bin") # Calculate chunk boundaries start_pos = chunk_id * chunk_size remaining = file_size - start_pos current_chunk_size = min(chunk_size, remaining) if current_chunk_size <= 0: break # Read and write chunk try: f.seek(start_pos) chunk_data = f.read(current_chunk_size) actual_chunk_size = len(chunk_data) if actual_chunk_size != current_chunk_size: print(f"[WARN] Chunk {chunk_id} size mismatch. Expected: {current_chunk_size}, Got: {actual_chunk_size}") with open(chunk_path, 'wb') as chunk_file: chunk_file.write(chunk_data) chunk_sizes.append(actual_chunk_size) print(f"[DEBUG] Chunk {chunk_id} data: First few bytes: {chunk_data[:20].hex()}") except Exception as e: raise Exception(f"Failed to process chunk {chunk_id} at offset {start_pos}: {str(e)}") # Create chunk metadata # Assign vocab_offset based on cumulative sizes of earlier chunks # so that chunks map to disjoint vocab ranges for aggregation. cumulative = 0 for cid, c in state.model_chunks.items(): try: cumulative += int(c.config.get('shard_dim', c.config.get('size_bytes', 1))) except Exception: cumulative += 1 cfg = { "start_offset": start_pos, "size_bytes": current_chunk_size, "is_last_chunk": chunk_id == num_chunks - 1, "total_chunks": num_chunks, "original_file": os.path.basename(model_file), # minimal shard mapping; users should adjust shard_dim to real local vocab size "vocab_offset": cumulative, # shard_dim should reflect how many vocab ids this chunk covers. # Default to 1 when unknown; prefer explicitly setting this in chunk metadata. "shard_dim": int(cfg.get('shard_dim', 1)) if isinstance(cfg := {} , dict) else 1 } state.model_chunks[chunk_id] = ModelChunk( chunk_id=chunk_id, files=[f"chunk_{chunk_id}.bin"], config=cfg, size_bytes=current_chunk_size, status="ready" ) print(f"[INFO] Created chunk {chunk_id}: {format_size(current_chunk_size)} ({current_chunk_size:,} bytes)") # Verify distribution total_size_actual = sum(chunk_sizes) if total_size_actual != file_size: print(f"[WARN] Total chunk size ({format_size(total_size_actual)}) differs from original file size ({format_size(file_size)})") print(f"[WARN] Difference: {format_size(abs(total_size_actual - file_size))}") # Calculate statistics avg_chunk_size = sum(chunk_sizes) / len(chunk_sizes) if chunk_sizes else 0 min_chunk_size = min(chunk_sizes) if chunk_sizes else 0 max_chunk_size = max(chunk_sizes) if chunk_sizes else 0 print(f"\n[INFO] Distribution Summary:") print(f"- Original file: {os.path.basename(model_file)}") print(f"- Total size: {format_size(file_size)} ({file_size:,} bytes)") print(f"- Number of chunks: {len(state.model_chunks)}") print(f"- Chunks directory: {state.chunks_dir}") print(f"- Average chunk size: {format_size(avg_chunk_size)}") print(f"- Smallest chunk: {format_size(min_chunk_size)}") print(f"- Largest chunk: {format_size(max_chunk_size)}") print(f"- Size variance: {((max_chunk_size - min_chunk_size) / avg_chunk_size * 100):.1f}%") return True except Exception as e: print(f"[ERROR] Failed to split model weights: {str(e)}") return False # Calculate total model size and chunks total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values()) num_servers = len(state.tensor_servers) or len(Settings.TENSOR_SERVER_URLS) # Determine optimal number of chunks based on server count # If 2 servers -> 2 chunks (500MB each for 1GB) # If 3 servers -> 3 chunks (333MB each for 1GB) num_chunks = num_servers bytes_per_chunk = math.ceil(total_size_bytes / num_chunks) print(f"[INFO] Total model size: {total_size_bytes / (1024*1024*1024):.2f} GB") print(f"[INFO] Available servers: {num_servers}") print(f"[INFO] Creating {num_chunks} chunks") print(f"[INFO] Target chunk size: {bytes_per_chunk / (1024*1024):.2f} MB") current_chunk = [] current_chunk_size = 0 chunk_id = 0 chunk_sizes = [] # Track actual chunk sizes for verification # Sort weights by size for better distribution sorted_weights = sorted( weights.items(), key=lambda x: x[1].nelement() * x[1].element_size(), reverse=True ) for key, tensor in weights.items(): tensor_size = tensor.numel() # Calculate tensor size in bytes tensor_size = tensor.nelement() * tensor.element_size() # If adding this tensor would exceed chunk size and we have tensors in current chunk if (current_chunk_size + tensor_size > bytes_per_chunk and current_chunk) or \ (chunk_id == num_chunks - 1): # Last chunk gets remaining tensors # Save current chunk chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors") chunk_weights = {k: weights[k] for k in current_chunk} torch.save(chunk_weights, chunk_path) # Calculate chunk stats chunk_total_size = sum(weights[k].nelement() * weights[k].element_size() for k in current_chunk) chunk_sizes.append(chunk_total_size) # Create chunk metadata state.model_chunks[chunk_id] = ModelChunk( chunk_id=chunk_id, files=[f"chunk_{chunk_id}.safetensors"], config={ "weight_keys": current_chunk, "size_bytes": chunk_total_size, "num_parameters": sum(weights[k].nelement() for k in current_chunk), "input_size": weights[current_chunk[0]].size(1) if len(current_chunk) > 0 else 0, "output_size": weights[current_chunk[-1]].size(0) if len(current_chunk) > 0 else 0, # assign a vocab_offset cumulatively "vocab_offset": sum(int(c.config.get('shard_dim', 1)) for c in state.model_chunks.values()), # Default shard_dim to 1; set correct value in chunk metadata if known "shard_dim": int(1) } ) print(f"[INFO] Created chunk {chunk_id}: {chunk_total_size / (1024*1024):.2f} MB, " f"{len(current_chunk)} tensors") # Reset for next chunk current_chunk = [] current_chunk_size = 0 chunk_id += 1 # If we've created all chunks except last one, put remaining tensors in last chunk if chunk_id == num_chunks - 1: remaining_tensors = [k for k, _ in sorted_weights if k not in sum([c.config["weight_keys"] for c in state.model_chunks.values()], [])] current_chunk.extend(remaining_tensors) continue # Add tensor to current chunk current_chunk.append(key) current_chunk_size += tensor_size # Save last chunk if not empty if current_chunk: chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors") chunk_weights = {k: weights[k] for k in current_chunk} torch.save(chunk_weights, chunk_path) # Calculate final chunk stats chunk_total_size = sum(weights[k].nelement() * weights[k].element_size() for k in current_chunk) chunk_sizes.append(chunk_total_size) state.model_chunks[chunk_id] = ModelChunk( chunk_id=chunk_id, files=[f"chunk_{chunk_id}.safetensors"], config={ "weight_keys": current_chunk, "size_bytes": chunk_total_size, "num_parameters": sum(weights[k].nelement() for k in current_chunk), "input_size": weights[current_chunk[0]].size(1), "output_size": weights[current_chunk[-1]].size(0) } ) print(f"[INFO] Created final chunk {chunk_id}: {chunk_total_size / (1024*1024):.2f} MB, " f"{len(current_chunk)} tensors") # Verify distribution total_size_actual = sum(chunk_sizes) size_std_dev = torch.tensor(chunk_sizes).std().item() / (1024*1024) # MB size_mean = torch.tensor(chunk_sizes).mean().item() / (1024*1024) # MB print(f"\n[INFO] Distribution Summary:") print(f"- Total model size: {total_size_actual / (1024*1024*1024):.2f} GB") print(f"- Number of chunks: {len(state.model_chunks)}") print(f"- Average chunk size: {size_mean:.2f} MB") print(f"- Chunk size std dev: {size_std_dev:.2f} MB") print(f"- Size variation: {(size_std_dev/size_mean*100):.1f}%") # Verify all weights were distributed all_distributed = set(sum([c.config["weight_keys"] for c in state.model_chunks.values()], [])) if len(all_distributed) != len(weights): missing = set(weights.keys()) - all_distributed print(f"[WARN] Some weights were not distributed: {missing}") return True except Exception as e: print(f"[ERROR] Failed to split model weights: {str(e)}") return False async def send_chunk_to_server(server_url: str, chunk_id: int, chunk_info: Dict): """Send a model chunk to a tensor server""" try: print(f"[INFO] Sending chunk {chunk_id} to server {server_url}") chunk_path = os.path.join(state.chunks_dir, f"chunk_{chunk_id}.bin") if not os.path.exists(chunk_path): raise Exception(f"Chunk file not found: {chunk_path}") # Get chunk metadata chunk = state.model_chunks[chunk_id] chunk_data = { 'chunk_id': chunk_id, 'files': [os.path.basename(chunk_path)], 'config': chunk.config } async with aiohttp.ClientSession() as session: # Step 1: Send chunk configuration async with session.post( f"{server_url}/load_chunk", json=chunk_data, timeout=Settings.TENSOR_SERVER_TIMEOUT ) as response: if response.status != 200: error_msg = await response.text() raise Exception(f"Failed to register chunk: {error_msg}") result = await response.json() if not result.get("ready_for_data", False): raise Exception("Server not ready for chunk data") # Step 2: Upload chunk data with open(chunk_path, 'rb') as f: chunk_file = f.read() form = aiohttp.FormData() form.add_field('file', chunk_file, filename=os.path.basename(chunk_path), content_type='application/octet-stream') async with session.post( f"{server_url}/upload_chunk_data/{chunk_id}", data=form, timeout=Settings.TENSOR_SERVER_TIMEOUT ) as upload_response: if upload_response.status != 200: error_msg = await upload_response.text() raise Exception(f"Failed to upload chunk data: {error_msg}") upload_result = await upload_response.json() print(f"[INFO] Successfully uploaded chunk {chunk_id} to {server_url} ({upload_result.get('size_bytes', 0)} bytes)") return True except Exception as e: print(f"[ERROR] Failed to send chunk {chunk_id} to {server_url}: {str(e)}") return False async def distribute_model_chunks(): """Distribute model chunks across available tensor servers""" try: available_servers = [ server for server in state.tensor_servers.values() if server.status in ["ready", "busy"] and server.metrics.error_count < Settings.MAX_ERROR_THRESHOLD ] min_required = Settings.get_min_servers_required() if len(available_servers) < min_required: raise Exception(f"Not enough healthy servers. Need {min_required}, got {len(available_servers)}") # Create or update weight chunks based on current server count if not state.model_chunks or len(state.model_chunks) > len(available_servers) * 3: if not await split_model_weights(): raise Exception("Failed to split model weights") # Prepare for parallel distribution tasks = [] min_replicas = Settings.get_min_replica_count(len(available_servers)) chunks_per_server = len(state.model_chunks) / len(available_servers) print(f"[INFO] Distributing chunks with min {min_replicas} replicas per chunk") print(f"[INFO] Target chunks per server: {chunks_per_server:.1f}") # Distribute chunks for chunk_id, chunk in state.model_chunks.items(): # Calculate optimal number of replicas based on chunk size and server capacity target_replicas = max(min_replicas, int(chunks_per_server * len(available_servers) / len(state.model_chunks))) current_assignments = set(chunk.server_assignments) current_healthy = [url for url in current_assignments if state.tensor_servers[url].status in ["ready", "busy"]] # Remove unhealthy assignments chunk.server_assignments = current_healthy # Add new assignments if needed while len(chunk.server_assignments) < target_replicas: # Find least loaded eligible server eligible_servers = [ server for server in available_servers if str(server.url) not in chunk.server_assignments and len(server.model_chunks) < (len(state.model_chunks) / len(available_servers) * 1.5) ] if not eligible_servers: break # Sort by load and error count eligible_servers.sort(key=lambda s: ( len(s.model_chunks), s.metrics.error_count, s.metrics.cpu_usage )) # Assign to best server best_server = eligible_servers[0] chunk.server_assignments.append(str(best_server.url)) best_server.model_chunks.append(chunk_id) print(f"[INFO] Assigned chunk {chunk_id} to server {best_server.url}") return True except Exception as e: print(f"[ERROR] Failed to distribute model chunks: {str(e)}") return False async def monitor_tensor_servers(): """Periodically check health and update metrics of all tensor servers""" while True: for server_url, server in state.tensor_servers.items(): try: # Check basic health is_healthy = await check_tensor_server_health(server_url) if not is_healthy: server.status = "error" server.metrics.error_count += 1 print(f"[WARN] Server {server_url} is unhealthy") continue # Get detailed metrics async with aiohttp.ClientSession() as session: async with session.get(f"{server_url}/metrics", timeout=Settings.TENSOR_SERVER_TIMEOUT) as response: if response.status == 200: metrics = await response.json() server.metrics = ServerMetrics(**metrics) # Update server status based on metrics if server.metrics.error_count > Settings.MAX_ERROR_THRESHOLD: server.status = "degraded" elif server.metrics.cpu_usage > 90 or server.metrics.memory_usage > 90: server.status = "busy" else: server.status = "ready" server.last_heartbeat = datetime.now() except Exception as e: print(f"[ERROR] Failed to monitor server {server_url}: {str(e)}") server.status = "error" server.metrics.last_error = str(e) server.metrics.error_count += 1 # Check for servers that haven't responded in a while current_time = datetime.now() for server_url, server in state.tensor_servers.items(): if (current_time - server.last_heartbeat).seconds > Settings.SERVER_TIMEOUT: print(f"[WARN] Server {server_url} hasn't responded in {Settings.SERVER_TIMEOUT} seconds") server.status = "error" await asyncio.sleep(Settings.MONITORING_INTERVAL) def get_next_model_version(base_dir: str, model_name: str) -> int: """Get the next available version number for the model""" existing_versions = [] model_base_dir = os.path.join(base_dir, model_name) if os.path.exists(model_base_dir): for d in os.listdir(model_base_dir): if d.startswith('v') and d[1:].isdigit(): existing_versions.append(int(d[1:])) return max(existing_versions + [0]) + 1 def check_existing_model(model_path: str) -> bool: """Check if a model exists and has required files""" if not os.path.exists(model_path): return False # Check for essential files required_files = ['config.json'] model_files = os.listdir(model_path) # Check for any weight files has_weights = any(f.endswith(('.bin', '.safetensors')) for f in model_files) return all(f in model_files for f in required_files) and has_weights async def download_model_files(): """Downloads the model files using Hugging Face Hub API""" try: print(f"[INFO] Processing model from {Settings.MODEL_REPO}...") # Install required packages if not present required_packages = ["huggingface_hub", "requests", "tqdm"] for package in required_packages: try: __import__(package) except ImportError: print(f"[INFO] Installing {package}...") import subprocess subprocess.check_call(["pip", "install", package]) from huggingface_hub import hf_hub_download, snapshot_download, HfFolder import requests from tqdm import tqdm # Create models directory models_dir = os.path.join(os.getcwd(), "models") os.makedirs(models_dir, exist_ok=True) print(f"[INFO] Models directory: {models_dir}") # Get the model name from the repository URL repo_id = "/".join(Settings.MODEL_REPO.split('/')[-2:]) # e.g., "facebook/opt-125m" model_name = repo_id.split('/')[-1] # Create versioned model directory version = get_next_model_version(models_dir, model_name) model_base_dir = os.path.join(models_dir, model_name) model_version_dir = os.path.join(model_base_dir, f"v{version}") # Function to download file with progress bar def download_file(url, filename): response = requests.get(url, stream=True) total_size = int(response.headers.get('content-length', 0)) with open(filename, 'wb') as f, tqdm( desc=os.path.basename(filename), total=total_size, unit='iB', unit_scale=True, unit_divisor=1024, ) as pbar: for data in response.iter_content(chunk_size=1024): size = f.write(data) pbar.update(size) # Check if previous version exists and is valid if version > 1: prev_version_dir = os.path.join(model_base_dir, f"v{version-1}") if check_existing_model(prev_version_dir): print(f"[INFO] Using existing model from {prev_version_dir}") model_path = prev_version_dir state.is_model_loaded = True else: # Download new version os.makedirs(model_version_dir, exist_ok=True) model_path = model_version_dir else: # First time download os.makedirs(model_version_dir, exist_ok=True) model_path = model_version_dir if not state.is_model_loaded: try: print(f"[INFO] Downloading model files from {repo_id}...") # Use snapshot_download to get all necessary files at once, which supports all weight file names print("[INFO] Downloading all model files (this may take a while)...") # snapshot_download is the most robust way to get all files matching patterns # This addresses the user's request to download model files that are not just "pytorch.bin" model_path = snapshot_download( repo_id=repo_id, local_dir=model_path, allow_patterns=["*.bin", "*.safetensors", "*.json", "*.txt", "tokenizer.model"], ignore_patterns=["*.msgpack", "*.onnx"], # Ignore non-PyTorch/safetensors formats force_download=True ) print(f"[INFO] All files downloaded to {model_path}") state.is_model_loaded = True except Exception as e: raise Exception(f"Failed to download model files: {str(e)}") # Set model paths in state state.model_path = model_path state.chunks_dir = os.path.join(model_path, "chunks") os.makedirs(state.chunks_dir, exist_ok=True) # Load and parse the config config_path = os.path.join(model_path, "config.json") if os.path.exists(config_path): with open(config_path, 'r') as f: state.model_config = json.load(f) print("[INFO] Loaded model configuration") print(f"[INFO] Model type: {state.model_config.get('model_type', 'unknown')}") print(f"[INFO] Architecture: {state.model_config.get('architectures', ['unknown'])[0]}") else: print("[WARN] No config.json found in model directory") # Scan for model files print("[INFO] Scanning for model files...") for root, _, files in os.walk(model_path): for file in files: if file.endswith(('.bin', '.json', '.safetensors')): file_path = os.path.join(root, file) state.model_files[file] = file_path print(f"[INFO] Found model file: {file}") if state.model_files: state.is_model_loaded = True print(f"[INFO] Model files found successfully! Total files: {len(state.model_files)}") print(f"[INFO] Model location: {model_path}") return True else: raise ValueError("No model files were found in the repository") except Exception as e: print(f"[ERROR] Failed to process model files: {e}") state.is_model_loaded = False raise async def check_tensor_server_health(url: HttpUrl) -> bool: """Checks if a tensor server is healthy""" try: async with aiohttp.ClientSession() as session: async with session.get(f"{url}/health", timeout=Settings.TENSOR_SERVER_TIMEOUT) as response: return response.status == 200 except: return False # ===== API Endpoints ===== async def execute_tensor_operation(operation_id: str, server_url: HttpUrl, operation: str, data: Dict): """Execute an operation on a tensor server and wait for results""" try: async with aiohttp.ClientSession() as session: # Start the operation async with session.post( f"{server_url}/{operation}", json=data, timeout=Settings.TENSOR_SERVER_TIMEOUT ) as response: if response.status != 200: error_msg = await response.text() raise HTTPException( status_code=response.status, detail=f"Operation failed on server {server_url}: {error_msg}" ) initial_response = await response.json() if initial_response.get("status") == "completed": # Operation completed immediately state.operation_results[operation_id] = initial_response return initial_response # Operation is async, poll for results while True: await asyncio.sleep(1) # Poll interval async with session.get( f"{server_url}/operation/{initial_response['operation_id']}", timeout=Settings.TENSOR_SERVER_TIMEOUT ) as status_response: if status_response.status != 200: raise HTTPException( status_code=status_response.status, detail=f"Failed to get operation status from {server_url}" ) status_data = await status_response.json() if status_data["status"] in ["completed", "failed"]: state.operation_results[operation_id] = status_data if status_data["status"] == "failed": raise HTTPException( status_code=500, detail=f"Operation failed on server {server_url}: {status_data.get('error')}" ) return status_data except asyncio.TimeoutError: raise HTTPException( status_code=504, detail=f"Operation timed out on server {server_url}" ) except Exception as e: raise HTTPException( status_code=500, detail=f"Error executing operation on {server_url}: {str(e)}" ) @app.post("/execute/{operation}") async def execute_operation(operation: str, data: Dict): """Execute an operation across tensor servers and collect results""" operation_id = f"{operation}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{len(state.operation_results)}" # Get available servers with required chunks available_servers = [ server for server in state.tensor_servers.values() if server.status in ["ready", "busy"] and server.metrics.error_count < Settings.MAX_ERROR_THRESHOLD ] if not available_servers: raise HTTPException( status_code=503, detail="No available tensor servers" ) # Start operations on all relevant servers in parallel tasks = [] for server in available_servers: if operation in ["compute", "forward"]: # For compute operations, only use servers with required chunks required_chunks = data.get("required_chunks", []) if not all(chunk_id in server.model_chunks for chunk_id in required_chunks): continue task = asyncio.create_task( execute_tensor_operation( f"{operation_id}_{server.url}", server.url, operation, data ) ) tasks.append(task) state.pending_operations[f"{operation_id}_{server.url}"] = task if not tasks: raise HTTPException( status_code=400, detail="No servers available with required model chunks" ) try: # Wait for all operations to complete results = await asyncio.gather(*tasks) # Process and aggregate results aggregated_result = { "operation_id": operation_id, "status": "completed", "server_results": results, "timestamp": datetime.now().isoformat() } # Clean up for task_id in list(state.pending_operations.keys()): if task_id.startswith(operation_id): del state.pending_operations[task_id] return aggregated_result except Exception as e: # Cancel any remaining tasks for task in tasks: if not task.done(): task.cancel() # Clean up for task_id in list(state.pending_operations.keys()): if task_id.startswith(operation_id): del state.pending_operations[task_id] raise HTTPException( status_code=500, detail=f"Operation failed: {str(e)}" ) @app.get("/operation/{operation_id}") async def get_operation_status(operation_id: str): """Get the status of an operation""" # Check completed operations results = { k: v for k, v in state.operation_results.items() if k.startswith(operation_id) } if results: return { "operation_id": operation_id, "status": "completed", "results": results } # Check pending operations pending = { k: "running" for k in state.pending_operations.keys() if k.startswith(operation_id) } if pending: return { "operation_id": operation_id, "status": "running", "pending_servers": list(pending.keys()) } raise HTTPException( status_code=404, detail=f"Operation {operation_id} not found" ) @app.get("/") async def root(): """Health check endpoint""" return { "status": "running", "model_loaded": state.is_model_loaded, "registered_servers": len(state.tensor_servers), "downloaded_files": len(state.model_files), "config_loaded": bool(state.model_config) } @app.get("/health") async def health_check(): """Detailed health check""" return { "status": "healthy", "model_loaded": state.is_model_loaded, "registered_servers": len(state.tensor_servers), "downloaded_files": list(state.model_files.keys()), "config_loaded": bool(state.model_config), "model_type": state.model_config.get("model_type", "unknown") } @app.post("/register_tensor_server") async def register_tensor_server(server_url: HttpUrl): """Register a new tensor server""" if not await check_tensor_server_health(server_url): raise HTTPException(status_code=400, detail="Tensor server is not healthy") state.tensor_servers[str(server_url)] = TensorServer(url=server_url) print(f"[INFO] Registered new tensor server at {server_url}") # If model is loaded, automatically distribute chunks if state.is_model_loaded: print(f"[INFO] Model is loaded, starting distribution for new server {server_url}") try: # Create chunks if they don't exist if not state.model_chunks: if await split_model_weights(): print(f"[INFO] Successfully split model into {len(state.model_chunks)} chunks") else: print("[ERROR] Failed to split model weights") # Distribute chunks if await distribute_model_chunks(): print("[INFO] Successfully distributed chunks to tensor servers") else: print("[ERROR] Failed to distribute chunks") except Exception as e: print(f"[ERROR] Distribution error during server registration: {str(e)}") return { "status": "registered", "registered_servers": len(state.tensor_servers), "server_id": str(server_url), "model_loaded": state.is_model_loaded, "chunks_distributed": len(state.model_chunks) if state.model_chunks else 0 } @app.delete("/unregister_tensor_server") async def unregister_tensor_server(server_url: HttpUrl): """Unregister a tensor server""" if str(server_url) in state.tensor_servers: # Remove server assignments from chunks for chunk in state.model_chunks.values(): if str(server_url) in chunk.server_assignments: chunk.server_assignments.remove(str(server_url)) del state.tensor_servers[str(server_url)] print(f"[INFO] Unregistered tensor server at {server_url}") # Trigger redistribution of chunks await distribute_model_chunks() return {"status": "unregistered"} raise HTTPException(status_code=404, detail="Server not found") @app.get("/server/{server_url}/chunks") async def get_server_chunks(server_url: HttpUrl): """Get the chunks assigned to a specific server""" if str(server_url) not in state.tensor_servers: raise HTTPException(status_code=404, detail="Server not found") server = state.tensor_servers[str(server_url)] assigned_chunks = [ state.model_chunks[chunk_id] for chunk_id in server.model_chunks ] return { "server_status": server.status, "assigned_chunks": assigned_chunks, "metrics": server.metrics.dict() } @app.post("/redistribute") async def redistribute_chunks(): """Manually trigger redistribution of model chunks""" success = await distribute_model_chunks() if not success: raise HTTPException(status_code=500, detail="Failed to redistribute chunks") return { "status": "redistributed", "chunk_assignments": { chunk_id: chunk.server_assignments for chunk_id, chunk in state.model_chunks.items() } } @app.get("/chunks/{chunk_id}/status") async def get_chunk_status(chunk_id: int): """Get the status and assignments of a specific chunk""" if chunk_id not in state.model_chunks: raise HTTPException(status_code=404, detail="Chunk not found") chunk = state.model_chunks[chunk_id] return { "chunk_id": chunk_id, "status": chunk.status, "server_assignments": chunk.server_assignments, "metrics": chunk.metrics } @app.post("/initialize") async def initialize_system(): """Download model files and prepare for distribution""" await download_model_files() # Verify downloaded files files_status = {} total_size = 0 for filename, filepath in state.model_files.items(): exists = os.path.exists(filepath) if exists: size = os.path.getsize(filepath) total_size += size files_status[filename] = {"exists": exists, "size_bytes": size} else: files_status[filename] = {"exists": exists, "size_bytes": 0} # Start model distribution if we have tensor servers distribution_status = "not_started" if state.tensor_servers: print("[INFO] Starting automatic model distribution...") try: # Split model into chunks if await split_model_weights(): print(f"[INFO] Successfully split model into {len(state.model_chunks)} chunks") # Distribute chunks to servers if await distribute_model_chunks(): print("[INFO] Successfully distributed chunks to tensor servers") distribution_status = "completed" else: print("[ERROR] Failed to distribute chunks") distribution_status = "distribution_failed" else: print("[ERROR] Failed to split model weights") distribution_status = "split_failed" except Exception as e: print(f"[ERROR] Distribution error: {str(e)}") distribution_status = f"error: {str(e)}" else: print("[INFO] No tensor servers registered yet. Will distribute when servers register.") return { "status": "initialized", "model_loaded": state.is_model_loaded, "files_status": files_status, "total_size_bytes": total_size, "config_loaded": bool(state.model_config), "model_type": state.model_config.get("model_type", "unknown"), "architecture": state.model_config.get("architectures", ["unknown"])[0], "distribution_status": distribution_status, "registered_servers": len(state.tensor_servers), "chunks_created": len(state.model_chunks) if state.model_chunks else 0 } # ===== Main Execution ===== @app.on_event("startup") async def startup_event(): """Initialize the server and start distribution""" print("[INFO] Initializing system...") try: # Initialize system and download model await initialize_system() print("[INFO] Model initialization complete") # Split model into chunks if await split_model_weights(): print(f"[INFO] Successfully split model into {len(state.model_chunks)} chunks") # Distribute chunks to tensor servers print("[INFO] Starting chunk distribution...") distribution_tasks = [] # One chunk per server distribution for chunk_id, chunk in state.model_chunks.items(): # Assign each chunk to exactly one server server_index = chunk_id % len(Settings.TENSOR_SERVER_URLS) server_url = Settings.TENSOR_SERVER_URLS[server_index] # Create task for distributing this chunk to its assigned server task = asyncio.create_task( send_chunk_to_server(server_url, chunk_id, {"chunk_id": chunk_id}) ) distribution_tasks.append(task) print(f"[INFO] Sending chunk {chunk_id} to {server_url}") # Track assignments for future reference try: chunk.server_assignments.append(server_url) except Exception: pass if distribution_tasks: print(f"[INFO] Distributing {len(distribution_tasks)} chunks...") results = await asyncio.gather(*distribution_tasks, return_exceptions=True) success_count = sum(1 for r in results if r is True) print(f"[INFO] Successfully distributed {success_count} chunks out of {len(distribution_tasks)} attempts") else: print("[ERROR] Failed to split model weights") except Exception as e: print(f"[ERROR] Startup error: {str(e)}") print("[INFO] Startup complete") if __name__ == "__main__": port = int(os.getenv("PORT", 8000)) print(f"[INFO] Starting controller server on port {port}") print(f"[INFO] API Documentation available at http://localhost:{port}/docs") uvicorn.run( "controller_server_new:app", host="0.0.0.0", port=port, reload=False )