|
|
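"""Controller service for distributed model serving.

Downloads model weights from Hugging Face, splits the main weight file into
byte-range chunks, registers tensor servers, assigns chunk replicas to them,
and exposes HTTP endpoints for uploading chunks and fanning out operations.
"""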
import os |
|
|
import json |
|
|
from datetime import datetime |
|
|
import asyncio |
|
|
import aiohttp |
|
|
from typing import Dict, List, Optional |
|
|
from fastapi import FastAPI, HTTPException |
|
|
from pydantic import BaseModel, Field, HttpUrl
|
|
import uvicorn |
|
|
from git_clone import clone_repository |
|
|
|
|
|
|
|
|
class Settings: |
|
|
|
|
|
CONTROLLER_HOST = "0.0.0.0" |
|
|
CONTROLLER_PORT = 8000 |
|
|
|
|
|
CONTROLLER_BASE_URL = os.getenv("CONTROLLER_BASE_URL", "http://192.168.1.100:8000") |
|
|
|
|
|
|
|
|
TENSOR_SERVER_URLS = [ |
|
|
url for url in os.getenv("TENSOR_SERVER_URLS", "").split(",") if url |
|
|
] or [ |
|
|
"https://fred808-ilob.hf.space", |
|
|
"https://fred808-tserv.hf.space", |
|
|
"https://fred808-tserve2.hf.space", |
|
|
] |
|
|
AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002") |
|
|
|
|
|
|
|
|
MODEL_REPO = "https://huggingface.co/inference-net/Schematron-8B" |
|
|
|
|
|
|
|
|
TENSOR_SERVER_TIMEOUT = 30   # seconds, per HTTP request to a tensor server
MAX_ERROR_THRESHOLD = 5      # error count at which a server is treated as degraded
SERVER_TIMEOUT = 60          # seconds without a heartbeat before a server is marked "error"
MONITORING_INTERVAL = 15     # seconds between health-check sweeps
|
|
|
|
|
|
|
|
@classmethod |
|
|
def get_optimal_chunk_size(cls, total_params: int, num_servers: int) -> int: |
|
|
"""Calculate optimal chunk size based on number of servers""" |
|
|
|
|
|
target_chunks = num_servers * 2 |
|
|
return max(1, total_params // target_chunks) |
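# Example: total_params=8_000_000_000 and num_servers=3 gives target_chunks=6,
# so each chunk covers roughly 1.33B parameters.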
|
|
|
|
|
@classmethod |
|
|
def get_min_servers_required(cls) -> int: |
|
|
"""Dynamically calculate minimum servers needed based on registered servers""" |
|
|
return max(2, len(cls.TENSOR_SERVER_URLS) // 3) |
|
|
|
|
|
@classmethod |
|
|
def get_min_replica_count(cls, num_servers: int) -> int: |
|
|
"""Calculate minimum replicas based on server count""" |
|
|
return max(2, num_servers // 4) |
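# Example: with the 3 default TENSOR_SERVER_URLS, get_min_servers_required() = max(2, 3 // 3) = 2;
# with 12 servers online, get_min_replica_count(12) = max(2, 12 // 4) = 3.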
|
|
|
|
|
|
|
|
MAX_SEQUENCE_LENGTH = 2048 |
|
|
VOCAB_SIZE = 50257 |
|
|
|
|
|
@classmethod |
|
|
def from_env(cls): |
|
|
"""Load settings from environment variables""" |
|
|
cls.CONTROLLER_HOST = os.getenv("CONTROLLER_HOST", cls.CONTROLLER_HOST) |
|
|
cls.CONTROLLER_PORT = int(os.getenv("CONTROLLER_PORT", cls.CONTROLLER_PORT)) |
|
|
cls.CONTROLLER_BASE_URL = os.getenv("CONTROLLER_BASE_URL", cls.CONTROLLER_BASE_URL) |
|
|
|
|
|
|
|
|
tensor_urls = os.getenv("TENSOR_SERVER_URLS") |
|
|
if tensor_urls: |
|
|
cls.TENSOR_SERVER_URLS = tensor_urls.split(",") |
|
|
|
|
|
# AGGREGATOR_HOST/PORT have no class-level defaults, so fall back to the values
# implied by the default AGGREGATOR_URL instead of reading missing attributes.
cls.AGGREGATOR_HOST = os.getenv("AGGREGATOR_HOST", "192.168.1.104")
cls.AGGREGATOR_PORT = int(os.getenv("AGGREGATOR_PORT", 8002))
cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL",
f"http://{cls.AGGREGATOR_HOST}:{cls.AGGREGATOR_PORT}")
|
|
|
|
|
return cls |
|
|
|
|
|
|
|
|
class ServerMetrics(BaseModel): |
|
|
"""Metrics for tensor server performance and load""" |
|
|
cpu_usage: float = 0.0 |
|
|
memory_usage: float = 0.0 |
|
|
gpu_usage: Optional[float] = None |
|
|
active_requests: int = 0 |
|
|
total_requests: int = 0 |
|
|
average_response_time: float = 0.0 |
|
|
last_error: Optional[str] = None |
|
|
error_count: int = 0 |
|
|
|
|
|
class TensorServer(BaseModel): |
|
|
"""Represents a registered tensor server""" |
|
|
url: HttpUrl |
|
|
status: str = "initializing" |
|
|
last_heartbeat: datetime = Field(default_factory=datetime.now)
|
|
model_chunks: List[int] = [] |
|
|
metrics: ServerMetrics = ServerMetrics() |
|
|
version: str = "1.0.0" |
|
|
capabilities: Dict[str, bool] = { |
|
|
"gpu_available": False, |
|
|
"quantization_support": False, |
|
|
"tensor_parallelism": False |
|
|
} |
|
|
|
|
|
class ModelChunk(BaseModel): |
|
|
"""Represents a chunk of the model to be sent to a tensor server""" |
|
|
chunk_id: int |
|
|
files: List[str] |
|
|
config: Dict |
|
|
size_bytes: int = 0 |
|
|
server_assignments: List[str] = [] |
|
|
status: str = "unassigned" |
|
|
metrics: Dict[str, float] = { |
|
|
"load_time": 0.0, |
|
|
"memory_usage": 0.0, |
|
|
"average_inference_time": 0.0 |
|
|
} |
|
|
|
|
|
|
|
|
app = FastAPI( |
|
|
title="Florence-2 Model Controller", |
|
|
description="Controls model distribution across tensor servers", |
|
|
version="1.0.0" |
|
|
) |
|
|
|
|
|
|
|
|
class ControllerState: |
|
|
def __init__(self): |
|
|
self.model_files: Dict[str, str] = {} |
|
|
self.model_config: Dict = {} |
|
|
self.tensor_servers: Dict[str, TensorServer] = {} |
|
|
self.model_chunks: Dict[int, ModelChunk] = {} |
|
|
self.is_model_loaded = False |
|
|
self.model_path: str = "" |
|
|
self.chunks_dir: str = "" |
|
|
self.operation_results: Dict[str, Dict] = {} |
|
|
self.pending_operations: Dict[str, asyncio.Task] = {} |
|
|
|
|
|
state = ControllerState() |
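# All coordination state lives in this single in-process object, so the controller
# assumes it runs as one uvicorn worker; a multi-worker deployment would need shared storage.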
|
|
|
|
|
|
|
|
async def split_model_weights(): |
|
|
"""Split model files into chunks based on available servers without loading into memory""" |
|
|
try:
import math
|
|
|
|
|
|
|
|
weight_files = [f for f in state.model_files.values() if f.endswith(('.safetensors', '.bin'))] |
|
|
|
|
|
if not weight_files: |
|
|
raise Exception("No model weight files found") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_file = max(weight_files, key=os.path.getsize) if len(weight_files) > 1 else weight_files[0] |
|
|
|
|
|
if len(weight_files) > 1: |
|
|
print(f"[WARN] Found multiple weight files. Selecting the largest one for splitting: {model_file}") |
|
|
else: |
|
|
print(f"[INFO] Found model weight file: {model_file}") |
|
|
|
|
|
|
|
|
try: |
|
|
with open(model_file, 'rb') as f: |
|
|
|
|
|
f.seek(0, 2) |
|
|
file_size = f.tell() |
|
|
f.seek(0) |
|
|
|
|
|
|
|
|
header = f.read(8) |
|
|
if len(header) == 0: |
|
|
raise ValueError(f"File is empty: {model_file}") |
|
|
except Exception as e: |
|
|
raise Exception(f"Failed to read model file {model_file}: {str(e)}") |
|
|
|
|
|
|
|
|
if file_size < 1024: |
|
|
raise ValueError(f"Model file suspiciously small ({file_size} bytes). Possible corruption or incomplete download.") |
|
|
|
|
|
num_servers = len(state.tensor_servers) or len(Settings.TENSOR_SERVER_URLS) |
|
|
num_chunks = num_servers |
|
|
|
|
|
chunk_size = math.ceil(file_size / num_chunks) |
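# Example: a 16 GiB weight file split across 3 servers yields chunks of
# ceil(file_size / 3) ≈ 5.34 GiB each; the final chunk simply holds the remainder.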
|
|
|
|
|
|
|
|
def format_size(size_bytes): |
|
|
if size_bytes >= 1024*1024*1024: |
|
|
return f"{size_bytes / (1024*1024*1024):.2f} GB ({size_bytes:,} bytes)" |
|
|
elif size_bytes >= 1024*1024: |
|
|
return f"{size_bytes / (1024*1024):.2f} MB ({size_bytes:,} bytes)" |
|
|
elif size_bytes >= 1024: |
|
|
return f"{size_bytes / 1024:.2f} KB ({size_bytes:,} bytes)" |
|
|
else: |
|
|
return f"{size_bytes:,} bytes" |
|
|
|
|
|
print(f"[INFO] Model file size: {format_size(file_size)}") |
|
|
print(f"[INFO] Creating {num_chunks} chunks of approximately {format_size(chunk_size)} each") |
|
|
|
|
|
|
|
|
os.makedirs(state.chunks_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
with open(model_file, 'rb') as f: |
|
|
chunk_sizes = [] |
|
|
for chunk_id in range(num_chunks): |
|
|
chunk_path = os.path.join(state.chunks_dir, f"chunk_{chunk_id}.bin") |
|
|
|
|
|
|
|
|
start_pos = chunk_id * chunk_size |
|
|
remaining = file_size - start_pos |
|
|
current_chunk_size = min(chunk_size, remaining) |
|
|
|
|
|
if current_chunk_size <= 0: |
|
|
break |
|
|
|
|
|
|
|
|
try: |
|
|
f.seek(start_pos) |
|
|
chunk_data = f.read(current_chunk_size) |
|
|
actual_chunk_size = len(chunk_data) |
|
|
|
|
|
if actual_chunk_size != current_chunk_size: |
|
|
print(f"[WARN] Chunk {chunk_id} size mismatch. Expected: {current_chunk_size}, Got: {actual_chunk_size}") |
|
|
|
|
|
with open(chunk_path, 'wb') as chunk_file: |
|
|
chunk_file.write(chunk_data) |
|
|
|
|
|
chunk_sizes.append(actual_chunk_size) |
|
|
print(f"[DEBUG] Chunk {chunk_id} data: First few bytes: {chunk_data[:20].hex()}") |
|
|
except Exception as e: |
|
|
raise Exception(f"Failed to process chunk {chunk_id} at offset {start_pos}: {str(e)}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# vocab_offset is the running total of shard dimensions of the chunks created so far;
# in this byte-level split every chunk contributes a shard_dim of 1.
cumulative = 0
for c in state.model_chunks.values():
try:
cumulative += int(c.config.get('shard_dim', 1))
except (TypeError, ValueError):
cumulative += 1

cfg = {
"start_offset": start_pos,
"size_bytes": current_chunk_size,
"is_last_chunk": chunk_id == num_chunks - 1,
"total_chunks": num_chunks,
"original_file": os.path.basename(model_file),
"vocab_offset": cumulative,
"shard_dim": 1
}
|
|
|
|
|
state.model_chunks[chunk_id] = ModelChunk( |
|
|
chunk_id=chunk_id, |
|
|
files=[f"chunk_{chunk_id}.bin"], |
|
|
config=cfg, |
|
|
size_bytes=current_chunk_size, |
|
|
status="ready" |
|
|
) |
|
|
|
|
|
print(f"[INFO] Created chunk {chunk_id}: {format_size(current_chunk_size)} ({current_chunk_size:,} bytes)") |
|
|
|
|
|
|
|
|
total_size_actual = sum(chunk_sizes) |
|
|
if total_size_actual != file_size: |
|
|
print(f"[WARN] Total chunk size ({format_size(total_size_actual)}) differs from original file size ({format_size(file_size)})") |
|
|
print(f"[WARN] Difference: {format_size(abs(total_size_actual - file_size))}") |
|
|
|
|
|
|
|
|
avg_chunk_size = sum(chunk_sizes) / len(chunk_sizes) if chunk_sizes else 0 |
|
|
min_chunk_size = min(chunk_sizes) if chunk_sizes else 0 |
|
|
max_chunk_size = max(chunk_sizes) if chunk_sizes else 0 |
|
|
|
|
|
print(f"\n[INFO] Distribution Summary:") |
|
|
print(f"- Original file: {os.path.basename(model_file)}") |
|
|
print(f"- Total size: {format_size(file_size)} ({file_size:,} bytes)") |
|
|
print(f"- Number of chunks: {len(state.model_chunks)}") |
|
|
print(f"- Chunks directory: {state.chunks_dir}") |
|
|
print(f"- Average chunk size: {format_size(avg_chunk_size)}") |
|
|
print(f"- Smallest chunk: {format_size(min_chunk_size)}") |
|
|
print(f"- Largest chunk: {format_size(max_chunk_size)}") |
|
|
print(f"- Size variance: {((max_chunk_size - min_chunk_size) / avg_chunk_size * 100):.1f}%") |
|
|
|
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
print(f"[ERROR] Failed to split model weights: {str(e)}") |
|
|
return False |
|
|
|
|
|
|
|
|
def split_model_weights_in_memory(weights) -> bool:
"""Legacy splitter (unused): distribute an already-loaded state_dict across servers by
grouping whole tensors into roughly equal-sized chunks. Requires torch; the byte-level
split_model_weights() above is the current path."""
import math
import torch

try:
total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
num_servers = len(state.tensor_servers) or len(Settings.TENSOR_SERVER_URLS)

num_chunks = num_servers
bytes_per_chunk = math.ceil(total_size_bytes / num_chunks)
|
|
|
|
|
print(f"[INFO] Total model size: {total_size_bytes / (1024*1024*1024):.2f} GB") |
|
|
print(f"[INFO] Available servers: {num_servers}") |
|
|
print(f"[INFO] Creating {num_chunks} chunks") |
|
|
print(f"[INFO] Target chunk size: {bytes_per_chunk / (1024*1024):.2f} MB") |
|
|
|
|
|
current_chunk = [] |
|
|
current_chunk_size = 0 |
|
|
chunk_id = 0 |
|
|
chunk_sizes = [] |
|
|
|
|
|
|
|
|
sorted_weights = sorted( |
|
|
weights.items(), |
|
|
key=lambda x: x[1].nelement() * x[1].element_size(), |
|
|
reverse=True |
|
|
) |
|
|
|
|
|
# Walk tensors largest-first so big tensors seed chunks and smaller ones fill the remaining space.
for key, tensor in sorted_weights:
tensor_size = tensor.nelement() * tensor.element_size()
|
|
|
|
|
|
|
|
if (current_chunk_size + tensor_size > bytes_per_chunk and current_chunk) or \ |
|
|
(chunk_id == num_chunks - 1): |
|
|
|
|
|
|
|
|
chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors") |
|
|
chunk_weights = {k: weights[k] for k in current_chunk} |
|
|
torch.save(chunk_weights, chunk_path) |
|
|
|
|
|
|
|
|
chunk_total_size = sum(weights[k].nelement() * weights[k].element_size() |
|
|
for k in current_chunk) |
|
|
chunk_sizes.append(chunk_total_size) |
|
|
|
|
|
|
|
|
state.model_chunks[chunk_id] = ModelChunk( |
|
|
chunk_id=chunk_id, |
|
|
files=[f"chunk_{chunk_id}.safetensors"], |
|
|
config={ |
|
|
"weight_keys": current_chunk, |
|
|
"size_bytes": chunk_total_size, |
|
|
"num_parameters": sum(weights[k].nelement() for k in current_chunk), |
|
|
"input_size": weights[current_chunk[0]].size(1) if len(current_chunk) > 0 else 0, |
|
|
"output_size": weights[current_chunk[-1]].size(0) if len(current_chunk) > 0 else 0, |
|
|
|
|
|
"vocab_offset": sum(int(c.config.get('shard_dim', 1)) for c in state.model_chunks.values()), |
|
|
|
|
|
"shard_dim": int(1) |
|
|
} |
|
|
) |
|
|
|
|
|
print(f"[INFO] Created chunk {chunk_id}: {chunk_total_size / (1024*1024):.2f} MB, " |
|
|
f"{len(current_chunk)} tensors") |
|
|
|
|
|
|
|
|
current_chunk = [] |
|
|
current_chunk_size = 0 |
|
|
chunk_id += 1 |
|
|
|
|
|
|
|
|
if chunk_id == num_chunks - 1: |
|
|
remaining_tensors = [k for k, _ in sorted_weights if k not in sum([c.config["weight_keys"] |
|
|
for c in state.model_chunks.values()], [])] |
|
|
current_chunk.extend(remaining_tensors) |
|
|
continue |
|
|
|
|
|
|
|
|
current_chunk.append(key) |
|
|
current_chunk_size += tensor_size |
|
|
|
|
|
|
|
|
if current_chunk: |
|
|
chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors") |
|
|
chunk_weights = {k: weights[k] for k in current_chunk} |
|
|
torch.save(chunk_weights, chunk_path) |
|
|
|
|
|
|
|
|
chunk_total_size = sum(weights[k].nelement() * weights[k].element_size() |
|
|
for k in current_chunk) |
|
|
chunk_sizes.append(chunk_total_size) |
|
|
|
|
|
state.model_chunks[chunk_id] = ModelChunk( |
|
|
chunk_id=chunk_id, |
|
|
files=[f"chunk_{chunk_id}.safetensors"], |
|
|
config={ |
|
|
"weight_keys": current_chunk, |
|
|
"size_bytes": chunk_total_size, |
|
|
"num_parameters": sum(weights[k].nelement() for k in current_chunk), |
|
|
"input_size": weights[current_chunk[0]].size(1), |
|
|
"output_size": weights[current_chunk[-1]].size(0) |
|
|
} |
|
|
) |
|
|
|
|
|
print(f"[INFO] Created final chunk {chunk_id}: {chunk_total_size / (1024*1024):.2f} MB, " |
|
|
f"{len(current_chunk)} tensors") |
|
|
|
|
|
|
|
|
total_size_actual = sum(chunk_sizes)
# torch.std() is undefined for integer tensors, so cast to float before computing stats.
sizes_tensor = torch.tensor(chunk_sizes, dtype=torch.float64)
size_std_dev = sizes_tensor.std().item() / (1024*1024)
size_mean = sizes_tensor.mean().item() / (1024*1024)
|
|
|
|
|
print(f"\n[INFO] Distribution Summary:") |
|
|
print(f"- Total model size: {total_size_actual / (1024*1024*1024):.2f} GB") |
|
|
print(f"- Number of chunks: {len(state.model_chunks)}") |
|
|
print(f"- Average chunk size: {size_mean:.2f} MB") |
|
|
print(f"- Chunk size std dev: {size_std_dev:.2f} MB") |
|
|
print(f"- Size variation: {(size_std_dev/size_mean*100):.1f}%") |
|
|
|
|
|
|
|
|
all_distributed = set(sum([c.config["weight_keys"] for c in state.model_chunks.values()], [])) |
|
|
if len(all_distributed) != len(weights): |
|
|
missing = set(weights.keys()) - all_distributed |
|
|
print(f"[WARN] Some weights were not distributed: {missing}") |
|
|
|
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
print(f"[ERROR] Failed to split model weights: {str(e)}") |
|
|
return False |
|
|
|
|
|
async def send_chunk_to_server(server_url: str, chunk_id: int, chunk_info: Dict): |
|
|
"""Send a model chunk to a tensor server""" |
|
|
try: |
|
|
print(f"[INFO] Sending chunk {chunk_id} to server {server_url}") |
|
|
chunk_path = os.path.join(state.chunks_dir, f"chunk_{chunk_id}.bin") |
|
|
|
|
|
if not os.path.exists(chunk_path): |
|
|
raise Exception(f"Chunk file not found: {chunk_path}") |
|
|
|
|
|
|
|
|
chunk = state.model_chunks[chunk_id] |
|
|
chunk_data = { |
|
|
'chunk_id': chunk_id, |
|
|
'files': [os.path.basename(chunk_path)], |
|
|
'config': chunk.config |
|
|
} |
|
|
|
|
|
async with aiohttp.ClientSession() as session: |
|
|
|
|
|
async with session.post( |
|
|
f"{server_url}/load_chunk", |
|
|
json=chunk_data, |
|
|
timeout=Settings.TENSOR_SERVER_TIMEOUT |
|
|
) as response: |
|
|
if response.status != 200: |
|
|
error_msg = await response.text() |
|
|
raise Exception(f"Failed to register chunk: {error_msg}") |
|
|
|
|
|
result = await response.json() |
|
|
if not result.get("ready_for_data", False): |
|
|
raise Exception("Server not ready for chunk data") |
|
|
|
|
|
|
|
|
with open(chunk_path, 'rb') as f: |
|
|
chunk_file = f.read() |
|
|
|
|
|
form = aiohttp.FormData() |
|
|
form.add_field('file', |
|
|
chunk_file, |
|
|
filename=os.path.basename(chunk_path), |
|
|
content_type='application/octet-stream') |
|
|
|
|
|
async with session.post( |
|
|
f"{server_url}/upload_chunk_data/{chunk_id}", |
|
|
data=form, |
|
|
timeout=Settings.TENSOR_SERVER_TIMEOUT |
|
|
) as upload_response: |
|
|
if upload_response.status != 200: |
|
|
error_msg = await upload_response.text() |
|
|
raise Exception(f"Failed to upload chunk data: {error_msg}") |
|
|
|
|
|
upload_result = await upload_response.json() |
|
|
print(f"[INFO] Successfully uploaded chunk {chunk_id} to {server_url} ({upload_result.get('size_bytes', 0)} bytes)") |
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
print(f"[ERROR] Failed to send chunk {chunk_id} to {server_url}: {str(e)}") |
|
|
return False |
|
|
|
|
|
async def distribute_model_chunks(): |
|
|
"""Distribute model chunks across available tensor servers""" |
|
|
try: |
|
|
available_servers = [ |
|
|
server for server in state.tensor_servers.values() |
|
|
if server.status in ["ready", "busy"] and server.metrics.error_count < Settings.MAX_ERROR_THRESHOLD |
|
|
] |
|
|
|
|
|
min_required = Settings.get_min_servers_required() |
|
|
if len(available_servers) < min_required: |
|
|
raise Exception(f"Not enough healthy servers. Need {min_required}, got {len(available_servers)}") |
|
|
|
|
|
|
|
|
if not state.model_chunks or len(state.model_chunks) > len(available_servers) * 3: |
|
|
if not await split_model_weights(): |
|
|
raise Exception("Failed to split model weights") |
|
|
|
|
|
|
|
|
tasks = [] |
|
|
min_replicas = Settings.get_min_replica_count(len(available_servers)) |
|
|
chunks_per_server = len(state.model_chunks) / len(available_servers) |
|
|
print(f"[INFO] Distributing chunks with min {min_replicas} replicas per chunk") |
|
|
print(f"[INFO] Target chunks per server: {chunks_per_server:.1f}") |
|
|
|
|
|
|
|
|
for chunk_id, chunk in state.model_chunks.items(): |
|
|
|
|
|
# chunks_per_server * num_servers / num_chunks always works out to 1, so the effective
# target is simply the minimum replica count.
target_replicas = min_replicas
|
|
|
|
|
current_assignments = set(chunk.server_assignments) |
|
|
current_healthy = [url for url in current_assignments
if url in state.tensor_servers
and state.tensor_servers[url].status in ["ready", "busy"]]
|
|
|
|
|
|
|
|
chunk.server_assignments = current_healthy |
|
|
|
|
|
|
|
|
while len(chunk.server_assignments) < target_replicas: |
|
|
|
|
|
eligible_servers = [ |
|
|
server for server in available_servers |
|
|
if str(server.url) not in chunk.server_assignments |
|
|
and len(server.model_chunks) < (len(state.model_chunks) / len(available_servers) * 1.5) |
|
|
] |
|
|
|
|
|
if not eligible_servers: |
|
|
break |
|
|
|
|
|
|
|
|
eligible_servers.sort(key=lambda s: ( |
|
|
len(s.model_chunks), |
|
|
s.metrics.error_count, |
|
|
s.metrics.cpu_usage |
|
|
)) |
|
|
|
|
|
|
|
|
best_server = eligible_servers[0] |
|
|
chunk.server_assignments.append(str(best_server.url)) |
|
|
best_server.model_chunks.append(chunk_id) |
|
|
print(f"[INFO] Assigned chunk {chunk_id} to server {best_server.url}") |
|
|
|
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
print(f"[ERROR] Failed to distribute model chunks: {str(e)}") |
|
|
return False |
|
|
|
|
|
async def monitor_tensor_servers(): |
|
|
"""Periodically check health and update metrics of all tensor servers""" |
|
|
while True: |
|
|
for server_url, server in state.tensor_servers.items(): |
|
|
try: |
|
|
|
|
|
is_healthy = await check_tensor_server_health(server_url) |
|
|
|
|
|
if not is_healthy: |
|
|
server.status = "error" |
|
|
server.metrics.error_count += 1 |
|
|
print(f"[WARN] Server {server_url} is unhealthy") |
|
|
continue |
|
|
|
|
|
|
|
|
async with aiohttp.ClientSession() as session: |
|
|
async with session.get(f"{server_url}/metrics", timeout=Settings.TENSOR_SERVER_TIMEOUT) as response: |
|
|
if response.status == 200: |
|
|
metrics = await response.json() |
|
|
server.metrics = ServerMetrics(**metrics) |
|
|
|
|
|
|
|
|
if server.metrics.error_count > Settings.MAX_ERROR_THRESHOLD: |
|
|
server.status = "degraded" |
|
|
elif server.metrics.cpu_usage > 90 or server.metrics.memory_usage > 90: |
|
|
server.status = "busy" |
|
|
else: |
|
|
server.status = "ready" |
|
|
|
|
|
server.last_heartbeat = datetime.now() |
|
|
|
|
|
except Exception as e: |
|
|
print(f"[ERROR] Failed to monitor server {server_url}: {str(e)}") |
|
|
server.status = "error" |
|
|
server.metrics.last_error = str(e) |
|
|
server.metrics.error_count += 1 |
|
|
|
|
|
|
|
|
current_time = datetime.now() |
|
|
for server_url, server in state.tensor_servers.items(): |
|
|
if (current_time - server.last_heartbeat).total_seconds() > Settings.SERVER_TIMEOUT:
|
|
print(f"[WARN] Server {server_url} hasn't responded in {Settings.SERVER_TIMEOUT} seconds") |
|
|
server.status = "error" |
|
|
|
|
|
await asyncio.sleep(Settings.MONITORING_INTERVAL) |
|
|
|
|
|
def get_next_model_version(base_dir: str, model_name: str) -> int: |
|
|
"""Get the next available version number for the model""" |
|
|
existing_versions = [] |
|
|
model_base_dir = os.path.join(base_dir, model_name) |
|
|
if os.path.exists(model_base_dir): |
|
|
for d in os.listdir(model_base_dir): |
|
|
if d.startswith('v') and d[1:].isdigit(): |
|
|
existing_versions.append(int(d[1:])) |
|
|
return max(existing_versions + [0]) + 1 |
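# Example: if models/Schematron-8B/ already contains v1/ and v2/, the next version is 3.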
|
|
|
|
|
def check_existing_model(model_path: str) -> bool: |
|
|
"""Check if a model exists and has required files""" |
|
|
if not os.path.exists(model_path): |
|
|
return False |
|
|
|
|
|
|
|
|
required_files = ['config.json'] |
|
|
model_files = os.listdir(model_path) |
|
|
|
|
|
|
|
|
has_weights = any(f.endswith(('.bin', '.safetensors')) for f in model_files) |
|
|
|
|
|
return all(f in model_files for f in required_files) and has_weights |
|
|
|
|
|
async def download_model_files(): |
|
|
"""Downloads the model files using Hugging Face Hub API""" |
|
|
try: |
|
|
print(f"[INFO] Processing model from {Settings.MODEL_REPO}...") |
|
|
|
|
|
|
|
|
required_packages = ["huggingface_hub", "requests", "tqdm"] |
|
|
for package in required_packages: |
|
|
try: |
|
|
__import__(package) |
|
|
except ImportError: |
|
|
print(f"[INFO] Installing {package}...") |
|
|
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
|
|
|
|
|
from huggingface_hub import hf_hub_download, snapshot_download, HfFolder |
|
|
import requests |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
models_dir = os.path.join(os.getcwd(), "models") |
|
|
os.makedirs(models_dir, exist_ok=True) |
|
|
print(f"[INFO] Models directory: {models_dir}") |
|
|
|
|
|
|
|
|
repo_id = "/".join(Settings.MODEL_REPO.split('/')[-2:]) |
|
|
model_name = repo_id.split('/')[-1] |
|
|
|
|
|
|
|
|
version = get_next_model_version(models_dir, model_name) |
|
|
model_base_dir = os.path.join(models_dir, model_name) |
|
|
model_version_dir = os.path.join(model_base_dir, f"v{version}") |
|
|
|
|
|
|
|
|
def download_file(url, filename): |
|
|
response = requests.get(url, stream=True) |
|
|
total_size = int(response.headers.get('content-length', 0)) |
|
|
|
|
|
with open(filename, 'wb') as f, tqdm( |
|
|
desc=os.path.basename(filename), |
|
|
total=total_size, |
|
|
unit='iB', |
|
|
unit_scale=True, |
|
|
unit_divisor=1024, |
|
|
) as pbar: |
|
|
for data in response.iter_content(chunk_size=1024): |
|
|
size = f.write(data) |
|
|
pbar.update(size) |
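# NOTE: download_file is a direct-HTTP helper with a tqdm progress bar; the snapshot_download
# call below is what actually fetches the repository in the current flow.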
|
|
|
|
|
|
|
|
if version > 1: |
|
|
prev_version_dir = os.path.join(model_base_dir, f"v{version-1}") |
|
|
if check_existing_model(prev_version_dir): |
|
|
print(f"[INFO] Using existing model from {prev_version_dir}") |
|
|
model_path = prev_version_dir |
|
|
state.is_model_loaded = True |
|
|
else: |
|
|
|
|
|
os.makedirs(model_version_dir, exist_ok=True) |
|
|
model_path = model_version_dir |
|
|
else: |
|
|
|
|
|
os.makedirs(model_version_dir, exist_ok=True) |
|
|
model_path = model_version_dir |
|
|
|
|
|
if not state.is_model_loaded: |
|
|
try: |
|
|
print(f"[INFO] Downloading model files from {repo_id}...") |
|
|
|
|
|
|
|
|
print("[INFO] Downloading all model files (this may take a while)...") |
|
|
|
|
|
|
|
|
|
|
|
model_path = snapshot_download( |
|
|
repo_id=repo_id, |
|
|
local_dir=model_path, |
|
|
allow_patterns=["*.bin", "*.safetensors", "*.json", "*.txt", "tokenizer.model"], |
|
|
ignore_patterns=["*.msgpack", "*.onnx"], |
|
|
force_download=True |
|
|
) |
|
|
|
|
|
print(f"[INFO] All files downloaded to {model_path}") |
|
|
state.is_model_loaded = True |
|
|
|
|
|
except Exception as e: |
|
|
raise Exception(f"Failed to download model files: {str(e)}") |
|
|
|
|
|
|
|
|
state.model_path = model_path |
|
|
state.chunks_dir = os.path.join(model_path, "chunks") |
|
|
os.makedirs(state.chunks_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
config_path = os.path.join(model_path, "config.json") |
|
|
if os.path.exists(config_path): |
|
|
with open(config_path, 'r') as f: |
|
|
state.model_config = json.load(f) |
|
|
print("[INFO] Loaded model configuration") |
|
|
print(f"[INFO] Model type: {state.model_config.get('model_type', 'unknown')}") |
|
|
print(f"[INFO] Architecture: {state.model_config.get('architectures', ['unknown'])[0]}") |
|
|
else: |
|
|
print("[WARN] No config.json found in model directory") |
|
|
|
|
|
|
|
|
print("[INFO] Scanning for model files...") |
|
|
for root, _, files in os.walk(model_path): |
|
|
for file in files: |
|
|
if file.endswith(('.bin', '.json', '.safetensors')): |
|
|
file_path = os.path.join(root, file) |
|
|
state.model_files[file] = file_path |
|
|
print(f"[INFO] Found model file: {file}") |
|
|
|
|
|
if state.model_files: |
|
|
state.is_model_loaded = True |
|
|
print(f"[INFO] Model files found successfully! Total files: {len(state.model_files)}") |
|
|
print(f"[INFO] Model location: {model_path}") |
|
|
return True |
|
|
else: |
|
|
raise ValueError("No model files were found in the repository") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"[ERROR] Failed to process model files: {e}") |
|
|
state.is_model_loaded = False |
|
|
raise |
|
|
|
|
|
async def check_tensor_server_health(url: HttpUrl) -> bool: |
|
|
"""Checks if a tensor server is healthy""" |
|
|
try: |
|
|
async with aiohttp.ClientSession() as session: |
|
|
async with session.get(f"{url}/health", timeout=Settings.TENSOR_SERVER_TIMEOUT) as response: |
|
|
return response.status == 200 |
|
|
except Exception:
return False
|
|
|
|
|
|
|
|
async def execute_tensor_operation(operation_id: str, server_url: HttpUrl, operation: str, data: Dict): |
|
|
"""Execute an operation on a tensor server and wait for results""" |
|
|
try: |
|
|
async with aiohttp.ClientSession() as session: |
|
|
|
|
|
async with session.post( |
|
|
f"{server_url}/{operation}", |
|
|
json=data, |
|
|
timeout=Settings.TENSOR_SERVER_TIMEOUT |
|
|
) as response: |
|
|
if response.status != 200: |
|
|
error_msg = await response.text() |
|
|
raise HTTPException( |
|
|
status_code=response.status, |
|
|
detail=f"Operation failed on server {server_url}: {error_msg}" |
|
|
) |
|
|
|
|
|
initial_response = await response.json() |
|
|
if initial_response.get("status") == "completed": |
|
|
|
|
|
state.operation_results[operation_id] = initial_response |
|
|
return initial_response |
|
|
|
|
|
|
|
|
while True: |
|
|
await asyncio.sleep(1) |
|
|
async with session.get( |
|
|
f"{server_url}/operation/{initial_response['operation_id']}", |
|
|
timeout=Settings.TENSOR_SERVER_TIMEOUT |
|
|
) as status_response: |
|
|
if status_response.status != 200: |
|
|
raise HTTPException( |
|
|
status_code=status_response.status, |
|
|
detail=f"Failed to get operation status from {server_url}" |
|
|
) |
|
|
|
|
|
status_data = await status_response.json() |
|
|
if status_data["status"] in ["completed", "failed"]: |
|
|
state.operation_results[operation_id] = status_data |
|
|
if status_data["status"] == "failed": |
|
|
raise HTTPException( |
|
|
status_code=500, |
|
|
detail=f"Operation failed on server {server_url}: {status_data.get('error')}" |
|
|
) |
|
|
return status_data |
|
|
|
|
|
except asyncio.TimeoutError: |
|
|
raise HTTPException( |
|
|
status_code=504, |
|
|
detail=f"Operation timed out on server {server_url}" |
|
|
) |
|
|
except HTTPException:
# Preserve HTTP errors raised above instead of collapsing them into a generic 500.
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error executing operation on {server_url}: {str(e)}"
)
|
|
|
|
|
@app.post("/execute/{operation}") |
|
|
async def execute_operation(operation: str, data: Dict): |
|
|
"""Execute an operation across tensor servers and collect results""" |
|
|
operation_id = f"{operation}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{len(state.operation_results)}" |
|
|
|
|
|
|
|
|
available_servers = [ |
|
|
server for server in state.tensor_servers.values() |
|
|
if server.status in ["ready", "busy"] |
|
|
and server.metrics.error_count < Settings.MAX_ERROR_THRESHOLD |
|
|
] |
|
|
|
|
|
if not available_servers: |
|
|
raise HTTPException( |
|
|
status_code=503, |
|
|
detail="No available tensor servers" |
|
|
) |
|
|
|
|
|
|
|
|
tasks = [] |
|
|
for server in available_servers: |
|
|
if operation in ["compute", "forward"]: |
|
|
|
|
|
required_chunks = data.get("required_chunks", []) |
|
|
if not all(chunk_id in server.model_chunks for chunk_id in required_chunks): |
|
|
continue |
|
|
|
|
|
task = asyncio.create_task( |
|
|
execute_tensor_operation( |
|
|
f"{operation_id}_{server.url}", |
|
|
server.url, |
|
|
operation, |
|
|
data |
|
|
) |
|
|
) |
|
|
tasks.append(task) |
|
|
state.pending_operations[f"{operation_id}_{server.url}"] = task |
|
|
|
|
|
if not tasks: |
|
|
raise HTTPException( |
|
|
status_code=400, |
|
|
detail="No servers available with required model chunks" |
|
|
) |
|
|
|
|
|
try: |
|
|
|
|
|
results = await asyncio.gather(*tasks) |
|
|
|
|
|
|
|
|
aggregated_result = { |
|
|
"operation_id": operation_id, |
|
|
"status": "completed", |
|
|
"server_results": results, |
|
|
"timestamp": datetime.now().isoformat() |
|
|
} |
|
|
|
|
|
|
|
|
for task_id in list(state.pending_operations.keys()): |
|
|
if task_id.startswith(operation_id): |
|
|
del state.pending_operations[task_id] |
|
|
|
|
|
return aggregated_result |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
for task in tasks: |
|
|
if not task.done(): |
|
|
task.cancel() |
|
|
|
|
|
|
|
|
for task_id in list(state.pending_operations.keys()): |
|
|
if task_id.startswith(operation_id): |
|
|
del state.pending_operations[task_id] |
|
|
|
|
|
raise HTTPException( |
|
|
status_code=500, |
|
|
detail=f"Operation failed: {str(e)}" |
|
|
) |
|
|
|
|
|
@app.get("/operation/{operation_id}") |
|
|
async def get_operation_status(operation_id: str): |
|
|
"""Get the status of an operation""" |
|
|
|
|
|
results = { |
|
|
k: v for k, v in state.operation_results.items() |
|
|
if k.startswith(operation_id) |
|
|
} |
|
|
|
|
|
if results: |
|
|
return { |
|
|
"operation_id": operation_id, |
|
|
"status": "completed", |
|
|
"results": results |
|
|
} |
|
|
|
|
|
|
|
|
pending = { |
|
|
k: "running" for k in state.pending_operations.keys() |
|
|
if k.startswith(operation_id) |
|
|
} |
|
|
|
|
|
if pending: |
|
|
return { |
|
|
"operation_id": operation_id, |
|
|
"status": "running", |
|
|
"pending_servers": list(pending.keys()) |
|
|
} |
|
|
|
|
|
raise HTTPException( |
|
|
status_code=404, |
|
|
detail=f"Operation {operation_id} not found" |
|
|
) |
|
|
|
|
|
@app.get("/") |
|
|
async def root(): |
|
|
"""Health check endpoint""" |
|
|
return { |
|
|
"status": "running", |
|
|
"model_loaded": state.is_model_loaded, |
|
|
"registered_servers": len(state.tensor_servers), |
|
|
"downloaded_files": len(state.model_files), |
|
|
"config_loaded": bool(state.model_config) |
|
|
} |
|
|
|
|
|
@app.get("/health") |
|
|
async def health_check(): |
|
|
"""Detailed health check""" |
|
|
return { |
|
|
"status": "healthy", |
|
|
"model_loaded": state.is_model_loaded, |
|
|
"registered_servers": len(state.tensor_servers), |
|
|
"downloaded_files": list(state.model_files.keys()), |
|
|
"config_loaded": bool(state.model_config), |
|
|
"model_type": state.model_config.get("model_type", "unknown") |
|
|
} |
|
|
|
|
|
@app.post("/register_tensor_server") |
|
|
async def register_tensor_server(server_url: HttpUrl): |
|
|
"""Register a new tensor server""" |
|
|
if not await check_tensor_server_health(server_url): |
|
|
raise HTTPException(status_code=400, detail="Tensor server is not healthy") |
|
|
|
|
|
state.tensor_servers[str(server_url)] = TensorServer(url=server_url) |
|
|
print(f"[INFO] Registered new tensor server at {server_url}") |
|
|
|
|
|
|
|
|
if state.is_model_loaded: |
|
|
print(f"[INFO] Model is loaded, starting distribution for new server {server_url}") |
|
|
try: |
|
|
|
|
|
if not state.model_chunks: |
|
|
if await split_model_weights(): |
|
|
print(f"[INFO] Successfully split model into {len(state.model_chunks)} chunks") |
|
|
else: |
|
|
print("[ERROR] Failed to split model weights") |
|
|
|
|
|
|
|
|
if await distribute_model_chunks(): |
|
|
print("[INFO] Successfully distributed chunks to tensor servers") |
|
|
else: |
|
|
print("[ERROR] Failed to distribute chunks") |
|
|
except Exception as e: |
|
|
print(f"[ERROR] Distribution error during server registration: {str(e)}") |
|
|
|
|
|
return { |
|
|
"status": "registered", |
|
|
"registered_servers": len(state.tensor_servers), |
|
|
"server_id": str(server_url), |
|
|
"model_loaded": state.is_model_loaded, |
|
|
"chunks_distributed": len(state.model_chunks) if state.model_chunks else 0 |
|
|
} |
|
|
|
|
|
@app.delete("/unregister_tensor_server") |
|
|
async def unregister_tensor_server(server_url: HttpUrl): |
|
|
"""Unregister a tensor server""" |
|
|
if str(server_url) in state.tensor_servers: |
|
|
|
|
|
for chunk in state.model_chunks.values(): |
|
|
if str(server_url) in chunk.server_assignments: |
|
|
chunk.server_assignments.remove(str(server_url)) |
|
|
|
|
|
del state.tensor_servers[str(server_url)] |
|
|
print(f"[INFO] Unregistered tensor server at {server_url}") |
|
|
|
|
|
|
|
|
await distribute_model_chunks() |
|
|
return {"status": "unregistered"} |
|
|
raise HTTPException(status_code=404, detail="Server not found") |
|
|
|
|
|
@app.get("/server/{server_url}/chunks") |
|
|
async def get_server_chunks(server_url: HttpUrl): |
|
|
"""Get the chunks assigned to a specific server""" |
|
|
if str(server_url) not in state.tensor_servers: |
|
|
raise HTTPException(status_code=404, detail="Server not found") |
|
|
|
|
|
server = state.tensor_servers[str(server_url)] |
|
|
assigned_chunks = [ |
|
|
state.model_chunks[chunk_id] |
|
|
for chunk_id in server.model_chunks |
|
|
] |
|
|
|
|
|
return { |
|
|
"server_status": server.status, |
|
|
"assigned_chunks": assigned_chunks, |
|
|
"metrics": server.metrics.dict() |
|
|
} |
|
|
|
|
|
@app.post("/redistribute") |
|
|
async def redistribute_chunks(): |
|
|
"""Manually trigger redistribution of model chunks""" |
|
|
success = await distribute_model_chunks() |
|
|
if not success: |
|
|
raise HTTPException(status_code=500, detail="Failed to redistribute chunks") |
|
|
|
|
|
return { |
|
|
"status": "redistributed", |
|
|
"chunk_assignments": { |
|
|
chunk_id: chunk.server_assignments |
|
|
for chunk_id, chunk in state.model_chunks.items() |
|
|
} |
|
|
} |
|
|
|
|
|
@app.get("/chunks/{chunk_id}/status") |
|
|
async def get_chunk_status(chunk_id: int): |
|
|
"""Get the status and assignments of a specific chunk""" |
|
|
if chunk_id not in state.model_chunks: |
|
|
raise HTTPException(status_code=404, detail="Chunk not found") |
|
|
|
|
|
chunk = state.model_chunks[chunk_id] |
|
|
return { |
|
|
"chunk_id": chunk_id, |
|
|
"status": chunk.status, |
|
|
"server_assignments": chunk.server_assignments, |
|
|
"metrics": chunk.metrics |
|
|
} |
|
|
|
|
|
@app.post("/initialize") |
|
|
async def initialize_system(): |
|
|
"""Download model files and prepare for distribution""" |
|
|
await download_model_files() |
|
|
|
|
|
|
|
|
files_status = {} |
|
|
total_size = 0 |
|
|
for filename, filepath in state.model_files.items(): |
|
|
exists = os.path.exists(filepath) |
|
|
if exists: |
|
|
size = os.path.getsize(filepath) |
|
|
total_size += size |
|
|
files_status[filename] = {"exists": exists, "size_bytes": size} |
|
|
else: |
|
|
files_status[filename] = {"exists": exists, "size_bytes": 0} |
|
|
|
|
|
|
|
|
distribution_status = "not_started" |
|
|
if state.tensor_servers: |
|
|
print("[INFO] Starting automatic model distribution...") |
|
|
try: |
|
|
|
|
|
if await split_model_weights(): |
|
|
print(f"[INFO] Successfully split model into {len(state.model_chunks)} chunks") |
|
|
|
|
|
if await distribute_model_chunks(): |
|
|
print("[INFO] Successfully distributed chunks to tensor servers") |
|
|
distribution_status = "completed" |
|
|
else: |
|
|
print("[ERROR] Failed to distribute chunks") |
|
|
distribution_status = "distribution_failed" |
|
|
else: |
|
|
print("[ERROR] Failed to split model weights") |
|
|
distribution_status = "split_failed" |
|
|
except Exception as e: |
|
|
print(f"[ERROR] Distribution error: {str(e)}") |
|
|
distribution_status = f"error: {str(e)}" |
|
|
else: |
|
|
print("[INFO] No tensor servers registered yet. Will distribute when servers register.") |
|
|
|
|
|
return { |
|
|
"status": "initialized", |
|
|
"model_loaded": state.is_model_loaded, |
|
|
"files_status": files_status, |
|
|
"total_size_bytes": total_size, |
|
|
"config_loaded": bool(state.model_config), |
|
|
"model_type": state.model_config.get("model_type", "unknown"), |
|
|
"architecture": state.model_config.get("architectures", ["unknown"])[0], |
|
|
"distribution_status": distribution_status, |
|
|
"registered_servers": len(state.tensor_servers), |
|
|
"chunks_created": len(state.model_chunks) if state.model_chunks else 0 |
|
|
} |
|
|
|
|
|
|
|
|
@app.on_event("startup") |
|
|
async def startup_event(): |
|
|
"""Initialize the server and start distribution""" |
|
|
print("[INFO] Initializing system...") |
|
|
try: |
|
|
|
|
|
await initialize_system() |
|
|
print("[INFO] Model initialization complete") |
|
|
|
|
|
|
|
|
if await split_model_weights(): |
|
|
print(f"[INFO] Successfully split model into {len(state.model_chunks)} chunks") |
|
|
|
|
|
|
|
|
print("[INFO] Starting chunk distribution...") |
|
|
distribution_tasks = [] |
|
|
|
|
|
|
|
|
for chunk_id, chunk in state.model_chunks.items(): |
|
|
|
|
|
server_index = chunk_id % len(Settings.TENSOR_SERVER_URLS) |
|
|
server_url = Settings.TENSOR_SERVER_URLS[server_index] |
|
|
|
|
|
task = asyncio.create_task( |
|
|
send_chunk_to_server(server_url, chunk_id, {"chunk_id": chunk_id}) |
|
|
) |
|
|
distribution_tasks.append(task) |
|
|
print(f"[INFO] Sending chunk {chunk_id} to {server_url}") |
|
|
|
|
|
try: |
|
|
chunk.server_assignments.append(server_url) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
if distribution_tasks: |
|
|
print(f"[INFO] Distributing {len(distribution_tasks)} chunks...") |
|
|
results = await asyncio.gather(*distribution_tasks, return_exceptions=True) |
|
|
success_count = sum(1 for r in results if r is True) |
|
|
print(f"[INFO] Successfully distributed {success_count} chunks out of {len(distribution_tasks)} attempts") |
|
|
else: |
|
|
print("[ERROR] Failed to split model weights") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"[ERROR] Startup error: {str(e)}") |
|
|
|
|
|
print("[INFO] Startup complete") |
|
|
|
|
|
if __name__ == "__main__": |
|
|
port = int(os.getenv("PORT", 8000)) |
|
|
print(f"[INFO] Starting controller server on port {port}") |
|
|
print(f"[INFO] API Documentation available at http://localhost:{port}/docs") |
|
|
|
|
|
uvicorn.run( |
|
|
"controller_server_new:app", |
|
|
host="0.0.0.0", |
|
|
port=port, |
|
|
reload=False |
|
|
) |
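# Typical bring-up, assuming this file is saved as controller_server_new.py:
#   python controller_server_new.py
#   curl -X POST http://localhost:8000/initialize
#   curl -X POST "http://localhost:8000/register_tensor_server?server_url=https://fred808-ilob.hf.space"
#   curl http://localhost:8000/health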