Fred808 committed (verified)
Commit 8a520fb · Parent: 4d4fb80

Update tensor_server.py

Files changed (1)
  1. tensor_server.py +755 -270
tensor_server.py CHANGED
@@ -1,271 +1,756 @@
1
- import os
2
- import json
3
- import torch
4
- import psutil
5
- import asyncio
6
- from datetime import datetime
7
- from typing import Dict, List, Optional
8
- from fastapi import FastAPI, HTTPException
9
- from pydantic import BaseModel
10
- import uvicorn
11
- import numpy as np
12
-
13
- # ===== Config =====
14
- class Settings:
15
- # Server configuration
16
- HOST = "0.0.0.0" # Listen on all interfaces
17
- PORT = 8001
18
- SERVER_ID = os.getenv("SERVER_ID", "tensor1") # Unique ID for this tensor server
19
-
20
- # The IP or hostname where this tensor server is accessible
21
- PUBLIC_URL = os.getenv("PUBLIC_URL", f"https://fred808-ilob.hf.space")
22
-
23
- # URLs for other services (should be actual IP addresses or hostnames)
24
- CONTROLLER_URL = os.getenv("CONTROLLER_URL", "http://192.168.1.100:8000")
25
- AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
26
-
27
- # Model settings
28
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
- MAX_BATCH_SIZE = 32
30
- METRICS_UPDATE_INTERVAL = 5 # seconds
31
- MODEL_DIR = "model_chunks"
32
-
33
- @classmethod
34
- def from_env(cls):
35
- """Load settings from environment variables"""
36
- cls.HOST = os.getenv("TENSOR_HOST", cls.HOST)
37
- cls.PORT = int(os.getenv("TENSOR_PORT", cls.PORT))
38
- cls.SERVER_ID = os.getenv("SERVER_ID", cls.SERVER_ID)
39
- cls.CONTROLLER_URL = os.getenv("CONTROLLER_URL", cls.CONTROLLER_URL)
40
- cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", cls.AGGREGATOR_URL)
41
- return cls
42
-
43
- # ===== Models =====
44
- class ModelChunk(BaseModel):
45
- """Represents a received model chunk configuration"""
46
- chunk_id: int
47
- files: List[str]
48
- config: Dict
49
-
50
- class InferenceRequest(BaseModel):
51
- """Represents an inference request"""
52
- inputs: List[List[float]]
53
- batch_size: Optional[int] = None
54
-
55
- class MetricsData(BaseModel):
56
- """Server metrics data"""
57
- cpu_usage: float
58
- memory_usage: float
59
- gpu_usage: Optional[float]
60
- active_requests: int
61
- total_requests: int
62
- average_response_time: float
63
- last_error: Optional[str]
64
- error_count: int
65
-
66
- # ===== FastAPI App =====
67
- app = FastAPI(
68
- title="Tensor Server",
69
- description="Handles model chunk computations",
70
- version="1.0.0"
71
- )
72
-
73
- # ===== State =====
74
- class ServerState:
75
- def __init__(self):
76
- self.loaded_chunks: Dict[int, torch.nn.Module] = {}
77
- self.active_requests: int = 0
78
- self.total_requests: int = 0
79
- self.request_times: List[float] = []
80
- self.error_count: int = 0
81
- self.last_error: Optional[str] = None
82
- self.is_computing: bool = False
83
-
84
- state = ServerState()
85
-
86
- # ===== Metrics Collection =====
87
- async def collect_metrics() -> MetricsData:
88
- """Collect current server metrics"""
89
- # CPU and memory metrics
90
- cpu_usage = psutil.cpu_percent()
91
- memory = psutil.virtual_memory()
92
- memory_usage = memory.percent
93
-
94
- # GPU metrics if available
95
- gpu_usage = None
96
- if torch.cuda.is_available():
97
- try:
98
- gpu_usage = torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated() * 100
99
- except:
100
- pass
101
-
102
- # Calculate average response time
103
- avg_response_time = sum(state.request_times) / len(state.request_times) if state.request_times else 0
104
-
105
- return MetricsData(
106
- cpu_usage=cpu_usage,
107
- memory_usage=memory_usage,
108
- gpu_usage=gpu_usage,
109
- active_requests=state.active_requests,
110
- total_requests=state.total_requests,
111
- average_response_time=avg_response_time,
112
- last_error=state.last_error,
113
- error_count=state.error_count
114
- )
115
-
116
- async def update_metrics_loop():
117
- """Background task to update metrics periodically"""
118
- while True:
119
- try:
120
- metrics = await collect_metrics()
121
- # Store metrics for health checks
122
- state.current_metrics = metrics
123
- except Exception as e:
124
- print(f"[ERROR] Failed to update metrics: {str(e)}")
125
- await asyncio.sleep(Settings.METRICS_UPDATE_INTERVAL)
126
-
127
- # ===== Helper Functions =====
128
- def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
129
- """Load a model chunk into memory"""
130
- try:
131
- # Create chunk directory if it doesn't exist
132
- os.makedirs(Settings.MODEL_DIR, exist_ok=True)
133
-
134
- # Get chunk configuration
135
- input_size = chunk.config["input_size"]
136
- output_size = chunk.config["output_size"]
137
- weight_keys = chunk.config["weight_keys"]
138
-
139
- # Create a simple linear transformation for this chunk
140
- chunk_model = torch.nn.Linear(input_size, output_size)
141
- chunk_model = chunk_model.to(Settings.DEVICE)
142
-
143
- # Load the weights
144
- chunk_file = os.path.join(Settings.MODEL_DIR, chunk.files[0])
145
- if os.path.exists(chunk_file):
146
- weights = torch.load(chunk_file, map_location=Settings.DEVICE)
147
-
148
- # Initialize weights from the loaded state dict
149
- with torch.no_grad():
150
- # Combine weights if multiple keys
151
- if len(weight_keys) > 1:
152
- combined_weight = torch.cat([weights[k] for k in weight_keys], dim=0)
153
- chunk_model.weight.copy_(combined_weight)
154
- else:
155
- chunk_model.weight.copy_(weights[weight_keys[0]])
156
-
157
- return chunk_model
158
-
159
- except Exception as e:
160
- raise Exception(f"Failed to load chunk: {str(e)}")
161
-
162
- async def process_tensor(chunk_id: int, inputs: torch.Tensor) -> torch.Tensor:
163
- """Process input tensor through the specified chunk"""
164
- if chunk_id not in state.loaded_chunks:
165
- raise HTTPException(status_code=400, detail=f"Chunk {chunk_id} not loaded")
166
-
167
- chunk_model = state.loaded_chunks[chunk_id]
168
- with torch.no_grad():
169
- outputs = chunk_model(inputs)
170
- return outputs
171
-
172
- # ===== API Endpoints =====
173
- @app.get("/health")
174
- async def health_check():
175
- """Health check endpoint"""
176
- metrics = await collect_metrics()
177
- return {
178
- "status": "healthy",
179
- "device": Settings.DEVICE,
180
- "loaded_chunks": list(state.loaded_chunks.keys()),
181
- "metrics": metrics.dict()
182
- }
183
-
184
- @app.get("/metrics")
185
- async def get_metrics():
186
- """Get current server metrics"""
187
- return await collect_metrics()
188
-
189
- @app.post("/load_chunk")
190
- async def load_model_chunk(chunk: ModelChunk):
191
- """Load a model chunk into memory"""
192
- try:
193
- # Load the chunk
194
- chunk_model = load_chunk(chunk)
195
- state.loaded_chunks[chunk.chunk_id] = chunk_model
196
-
197
- return {
198
- "status": "loaded",
199
- "chunk_id": chunk.chunk_id,
200
- "device": str(next(chunk_model.parameters()).device)
201
- }
202
-
203
- except Exception as e:
204
- state.error_count += 1
205
- state.last_error = str(e)
206
- raise HTTPException(status_code=500, detail=str(e))
207
-
208
- @app.post("/compute/{chunk_id}")
209
- async def compute(chunk_id: int, request: InferenceRequest):
210
- """Perform computation on inputs using specified chunk"""
211
- try:
212
- start_time = datetime.now()
213
- state.active_requests += 1
214
- state.total_requests += 1
215
-
216
- # Convert inputs to tensor
217
- inputs = torch.tensor(request.inputs, dtype=torch.float32, device=Settings.DEVICE)
218
-
219
- # Split into batches if needed
220
- batch_size = request.batch_size or Settings.MAX_BATCH_SIZE
221
- if len(inputs) > batch_size:
222
- batches = torch.split(inputs, batch_size)
223
- outputs = []
224
- for batch in batches:
225
- batch_output = await process_tensor(chunk_id, batch)
226
- outputs.append(batch_output)
227
- output_tensor = torch.cat(outputs, dim=0)
228
- else:
229
- output_tensor = await process_tensor(chunk_id, inputs)
230
-
231
- # Convert output to list
232
- output_list = output_tensor.cpu().numpy().tolist()
233
-
234
- # Update metrics
235
- end_time = datetime.now()
236
- processing_time = (end_time - start_time).total_seconds()
237
- state.request_times.append(processing_time)
238
- # Keep only last 100 request times
239
- state.request_times = state.request_times[-100:]
240
-
241
- return {
242
- "outputs": output_list,
243
- "processing_time": processing_time
244
- }
245
-
246
- except Exception as e:
247
- state.error_count += 1
248
- state.last_error = str(e)
249
- raise HTTPException(status_code=500, detail=str(e))
250
-
251
- finally:
252
- state.active_requests -= 1
253
-
254
- @app.on_event("startup")
255
- async def startup_event():
256
- """Start background tasks"""
257
- asyncio.create_task(update_metrics_loop())
258
-
259
- # ===== Main Execution =====
260
- if __name__ == "__main__":
261
- port = int(os.getenv("PORT", 8001)) # Default to 8001 to avoid conflict with controller
262
- print(f"[INFO] Starting tensor server on port {port}")
263
- print(f"[INFO] Using device: {Settings.DEVICE}")
264
- print(f"[INFO] API Documentation available at http://localhost:{port}/docs")
265
-
266
- uvicorn.run(
267
- "tensor_server:app",
268
- host="0.0.0.0",
269
- port=port,
270
- reload=False
271
  )
 
1
+ import os
2
+ import json
3
+ from datetime import datetime
4
+ import asyncio
5
+ import aiohttp
6
+ from typing import Dict, List, Optional
7
+ from fastapi import FastAPI, HTTPException
8
+ from pydantic import BaseModel, Field, HttpUrl
9
+ import uvicorn
10
+ from git_clone import clone_repository
11
+
12
+ # ===== CONFIG =====
13
+ class Settings:
14
+ # Server URLs and Ports
15
+ CONTROLLER_HOST = "0.0.0.0" # Listen on all interfaces
16
+ CONTROLLER_PORT = 8000
17
+ # This should be the actual IP or hostname where controller is accessible
18
+ CONTROLLER_BASE_URL = os.getenv("CONTROLLER_BASE_URL", "http://192.168.1.100:8000")
19
+
20
+ # List of tensor server URLs - should be actual IP addresses or hostnames
21
+ # Note: "".split(",") returns [""], so filter out empty entries before falling back to the defaults
+ TENSOR_SERVER_URLS = [u for u in os.getenv("TENSOR_SERVER_URLS", "").split(",") if u] or [
+     "https://fred808-ilob.hf.space",    # tensor server 1
+     "https://fred808-tserv.hf.space",   # tensor server 2
+     "https://fred808-tserve2.hf.space"  # tensor server 3
+ ]
26
+
27
+ # Aggregator settings - should be actual IP or hostname
28
+ AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
29
+
30
+ # Model settings
31
+ MODEL_REPO = "https://huggingface.co/microsoft/Florence-2-large"
32
+
33
+ # Server settings
34
+ TENSOR_SERVER_TIMEOUT = 30 # seconds
35
+ MAX_ERROR_THRESHOLD = 5 # maximum number of errors
36
+ SERVER_TIMEOUT = 60 # seconds before marking as error
37
+ MONITORING_INTERVAL = 15 # seconds between health checks
38
+
39
+ # Dynamic distribution settings
40
+ @classmethod
41
+ def get_optimal_chunk_size(cls, total_params: int, num_servers: int) -> int:
42
+ """Calculate optimal chunk size based on number of servers"""
43
+ # Aim for 2-3 chunks per server for better parallelism
44
+ target_chunks = num_servers * 2
45
+ return max(1, total_params // target_chunks)
46
+
47
+ @classmethod
48
+ def get_min_servers_required(cls) -> int:
49
+ """Dynamically calculate minimum servers needed based on registered servers"""
50
+ return max(2, len(cls.TENSOR_SERVER_URLS) // 3) # At least 1/3 of registered servers
51
+
52
+ @classmethod
53
+ def get_min_replica_count(cls, num_servers: int) -> int:
54
+ """Calculate minimum replicas based on server count"""
55
+ return max(2, num_servers // 4) # At least 25% of servers should have each chunk
56
+
57
+ # Tokenizer settings
58
+ MAX_SEQUENCE_LENGTH = 2048
59
+ VOCAB_SIZE = 50257
60
+
61
+ @classmethod
62
+ def from_env(cls):
63
+ """Load settings from environment variables"""
64
+ cls.CONTROLLER_HOST = os.getenv("CONTROLLER_HOST", cls.CONTROLLER_HOST)
65
+ cls.CONTROLLER_PORT = int(os.getenv("CONTROLLER_PORT", cls.CONTROLLER_PORT))
66
+ cls.CONTROLLER_BASE_URL = os.getenv("CONTROLLER_BASE_URL", cls.CONTROLLER_BASE_URL)
67
+
68
+ # Load tensor server URLs from environment
69
+ tensor_urls = os.getenv("TENSOR_SERVER_URLS")
70
+ if tensor_urls:
71
+ cls.TENSOR_SERVER_URLS = tensor_urls.split(",")
72
+
73
+ # AGGREGATOR_HOST/PORT have no class-level defaults; fall back to the host and port
+ # implied by the default AGGREGATOR_URL so from_env() cannot raise AttributeError
+ cls.AGGREGATOR_HOST = os.getenv("AGGREGATOR_HOST", "192.168.1.104")
+ cls.AGGREGATOR_PORT = int(os.getenv("AGGREGATOR_PORT", 8002))
+ cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL",
+     f"http://{cls.AGGREGATOR_HOST}:{cls.AGGREGATOR_PORT}")
77
+
78
+ return cls
79
+
80
+ # ===== State Models =====
81
+ class ServerMetrics(BaseModel):
82
+ """Metrics for tensor server performance and load"""
83
+ cpu_usage: float = 0.0
84
+ memory_usage: float = 0.0
85
+ gpu_usage: Optional[float] = None
86
+ active_requests: int = 0
87
+ total_requests: int = 0
88
+ average_response_time: float = 0.0
89
+ last_error: Optional[str] = None
90
+ error_count: int = 0
91
+
92
+ class TensorServer(BaseModel):
93
+ """Represents a registered tensor server"""
94
+ url: HttpUrl
95
+ status: str = "initializing" # initializing, ready, busy, error, degraded
96
+ last_heartbeat: datetime = Field(default_factory=datetime.now) # evaluated per instance, not once at import time
97
+ model_chunks: List[int] = [] # List of chunk IDs assigned to this server
98
+ metrics: ServerMetrics = ServerMetrics()
99
+ version: str = "1.0.0"
100
+ capabilities: Dict[str, bool] = {
101
+ "gpu_available": False,
102
+ "quantization_support": False,
103
+ "tensor_parallelism": False
104
+ }
105
+
106
+ class ModelChunk(BaseModel):
107
+ """Represents a chunk of the model to be sent to a tensor server"""
108
+ chunk_id: int
109
+ files: List[str] # files included in this chunk
110
+ config: Dict # configuration for this chunk
111
+ size_bytes: int = 0
112
+ server_assignments: List[str] = [] # URLs of servers holding this chunk
113
+ status: str = "unassigned" # unassigned, assigned, loaded, error
114
+ metrics: Dict[str, float] = {
115
+ "load_time": 0.0,
116
+ "memory_usage": 0.0,
117
+ "average_inference_time": 0.0
118
+ }
119
+
120
+ # ===== FastAPI App =====
121
+ app = FastAPI(
122
+ title="Florence-2 Model Controller",
123
+ description="Controls model distribution across tensor servers",
124
+ version="1.0.0"
125
+ )
126
+
127
+ # ===== Global State =====
128
+ class ControllerState:
129
+ def __init__(self):
130
+ self.model_files: Dict[str, str] = {} # Mapping of filename to file path
131
+ self.model_config: Dict = {} # Model configuration
+ self.model_path: Optional[str] = None # Directory of the downloaded model (set by download_model_files)
132
+ self.tensor_servers: Dict[str, TensorServer] = {}
133
+ self.model_chunks: Dict[int, ModelChunk] = {}
134
+ self.is_model_loaded = False
135
+ self.operation_results: Dict[str, Dict] = {} # Track operation results from tensor servers
136
+ self.pending_operations: Dict[str, asyncio.Task] = {} # Track ongoing operations
137
+
138
+ state = ControllerState()
139
+
140
+ # ===== Helper Functions =====
141
+ async def split_model_weights():
142
+ """Split model weights into chunks based on available servers"""
143
+ try:
144
+ import torch
145
+
146
+ # Load the full model weights; torch.load only reads pickle-based .bin checkpoints,
+ # so .safetensors files go through the safetensors library instead
+ model_file = next(f for f in state.model_files.values() if f.endswith(('.safetensors', '.bin')))
+ if model_file.endswith('.safetensors'):
+     from safetensors.torch import load_file
+     weights = load_file(model_file)
+ else:
+     weights = torch.load(model_file, map_location='cpu')
149
+
150
+ # Calculate chunks based on number of servers
151
+ total_params = sum(p.numel() for p in weights.values())
152
+ num_servers = len(state.tensor_servers) or len(Settings.TENSOR_SERVER_URLS)
153
+ params_per_chunk = Settings.get_optimal_chunk_size(total_params, num_servers)
154
+
155
+ print(f"[INFO] Total parameters: {total_params:,}")
156
+ print(f"[INFO] Available servers: {num_servers}")
157
+ print(f"[INFO] Parameters per chunk: {params_per_chunk:,}")
158
+
159
+ current_chunk = []
160
+ current_size = 0
161
+ chunk_id = 0
162
+
163
+ for key, tensor in weights.items():
164
+ tensor_size = tensor.numel()
165
+
166
+ if current_size + tensor_size > params_per_chunk and current_chunk:
167
+ # Save current chunk
168
+ chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors")
169
+ torch.save({k: weights[k] for k in current_chunk}, chunk_path)
170
+
171
+ # Create chunk metadata
172
+ state.model_chunks[chunk_id] = ModelChunk(
173
+ chunk_id=chunk_id,
174
+ files=[f"chunk_{chunk_id}.safetensors"],
175
+ config={
176
+ "weight_keys": current_chunk,
177
+ "input_size": weights[current_chunk[0]].size(1),
178
+ "output_size": weights[current_chunk[-1]].size(0)
179
+ }
180
+ )
181
+
182
+ # Reset for next chunk
183
+ current_chunk = []
184
+ current_size = 0
185
+ chunk_id += 1
186
+
187
+ current_chunk.append(key)
188
+ current_size += tensor_size
189
+
190
+ # Save last chunk if not empty
191
+ if current_chunk:
192
+ chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors")
193
+ torch.save({k: weights[k] for k in current_chunk}, chunk_path)
194
+
195
+ state.model_chunks[chunk_id] = ModelChunk(
196
+ chunk_id=chunk_id,
197
+ files=[f"chunk_{chunk_id}.safetensors"],
198
+ config={
199
+ "weight_keys": current_chunk,
200
+ "input_size": weights[current_chunk[0]].size(1),
201
+ "output_size": weights[current_chunk[-1]].size(0)
202
+ }
203
+ )
204
+
205
+ print(f"[INFO] Split model into {len(state.model_chunks)} chunks")
206
+ return True
207
+
208
+ except Exception as e:
209
+ print(f"[ERROR] Failed to split model weights: {str(e)}")
210
+ return False
211
+
212
+ async def distribute_model_chunks():
213
+ """Distribute model chunks across available tensor servers"""
214
+ try:
215
+ available_servers = [
216
+ server for server in state.tensor_servers.values()
217
+ if server.status in ["ready", "busy"] and server.metrics.error_count < Settings.MAX_ERROR_THRESHOLD
218
+ ]
219
+
220
+ min_required = Settings.get_min_servers_required()
221
+ if len(available_servers) < min_required:
222
+ raise Exception(f"Not enough healthy servers. Need {min_required}, got {len(available_servers)}")
223
+
224
+ # Create or update weight chunks based on current server count
225
+ if not state.model_chunks or len(state.model_chunks) > len(available_servers) * 3:
226
+ if not await split_model_weights():
227
+ raise Exception("Failed to split model weights")
228
+
229
+ # Prepare for parallel distribution
230
+ tasks = []
231
+ min_replicas = Settings.get_min_replica_count(len(available_servers))
232
+ chunks_per_server = len(state.model_chunks) / len(available_servers)
233
+ print(f"[INFO] Distributing chunks with min {min_replicas} replicas per chunk")
234
+ print(f"[INFO] Target chunks per server: {chunks_per_server:.1f}")
235
+
236
+ # Distribute chunks
237
+ for chunk_id, chunk in state.model_chunks.items():
238
+ # Calculate optimal number of replicas based on chunk size and server capacity
239
+ target_replicas = max(min_replicas,
240
+ int(chunks_per_server * len(available_servers) / len(state.model_chunks)))
241
+
242
+ current_assignments = set(chunk.server_assignments)
243
+ current_healthy = [url for url in current_assignments
244
+ if state.tensor_servers[url].status in ["ready", "busy"]]
245
+
246
+ # Remove unhealthy assignments
247
+ chunk.server_assignments = current_healthy
248
+
249
+ # Add new assignments if needed
250
+ while len(chunk.server_assignments) < target_replicas:
251
+ # Find least loaded eligible server
252
+ eligible_servers = [
253
+ server for server in available_servers
254
+ if str(server.url) not in chunk.server_assignments
255
+ and len(server.model_chunks) < (len(state.model_chunks) / len(available_servers) * 1.5)
256
+ ]
257
+
258
+ if not eligible_servers:
259
+ break
260
+
261
+ # Sort by load and error count
262
+ eligible_servers.sort(key=lambda s: (
263
+ len(s.model_chunks),
264
+ s.metrics.error_count,
265
+ s.metrics.cpu_usage
266
+ ))
267
+
268
+ # Assign to best server
269
+ best_server = eligible_servers[0]
270
+ chunk.server_assignments.append(str(best_server.url))
271
+ best_server.model_chunks.append(chunk_id)
272
+ print(f"[INFO] Assigned chunk {chunk_id} to server {best_server.url}")
273
+
274
+ return True
275
+
276
+ except Exception as e:
277
+ print(f"[ERROR] Failed to distribute model chunks: {str(e)}")
278
+ return False
279
+
280
+ async def monitor_tensor_servers():
281
+ """Periodically check health and update metrics of all tensor servers"""
282
+ while True:
283
+ for server_url, server in state.tensor_servers.items():
284
+ try:
285
+ # Check basic health
286
+ is_healthy = await check_tensor_server_health(server_url)
287
+
288
+ if not is_healthy:
289
+ server.status = "error"
290
+ server.metrics.error_count += 1
291
+ print(f"[WARN] Server {server_url} is unhealthy")
292
+ continue
293
+
294
+ # Get detailed metrics
295
+ async with aiohttp.ClientSession() as session:
296
+ async with session.get(f"{server_url}/metrics", timeout=Settings.TENSOR_SERVER_TIMEOUT) as response:
297
+ if response.status == 200:
298
+ metrics = await response.json()
299
+ server.metrics = ServerMetrics(**metrics)
300
+
301
+ # Update server status based on metrics
302
+ if server.metrics.error_count > Settings.MAX_ERROR_THRESHOLD:
303
+ server.status = "degraded"
304
+ elif server.metrics.cpu_usage > 90 or server.metrics.memory_usage > 90:
305
+ server.status = "busy"
306
+ else:
307
+ server.status = "ready"
308
+
309
+ server.last_heartbeat = datetime.now()
310
+
311
+ except Exception as e:
312
+ print(f"[ERROR] Failed to monitor server {server_url}: {str(e)}")
313
+ server.status = "error"
314
+ server.metrics.last_error = str(e)
315
+ server.metrics.error_count += 1
316
+
317
+ # Check for servers that haven't responded in a while
318
+ current_time = datetime.now()
319
+ for server_url, server in state.tensor_servers.items():
320
+ if (current_time - server.last_heartbeat).total_seconds() > Settings.SERVER_TIMEOUT:
321
+ print(f"[WARN] Server {server_url} hasn't responded in {Settings.SERVER_TIMEOUT} seconds")
322
+ server.status = "error"
323
+
324
+ await asyncio.sleep(Settings.MONITORING_INTERVAL)
325
+
326
+ def get_next_model_version(base_dir: str, model_name: str) -> int:
327
+ """Get the next available version number for the model"""
328
+ existing_versions = []
329
+ model_base_dir = os.path.join(base_dir, model_name)
330
+ if os.path.exists(model_base_dir):
331
+ for d in os.listdir(model_base_dir):
332
+ if d.startswith('v') and d[1:].isdigit():
333
+ existing_versions.append(int(d[1:]))
334
+ return max(existing_versions + [0]) + 1
335
+
336
+ def check_existing_model(model_path: str) -> bool:
337
+ """Check if a model exists and has required files"""
338
+ if not os.path.exists(model_path):
339
+ return False
340
+
341
+ # Check for essential files
342
+ required_files = ['config.json']
343
+ model_files = os.listdir(model_path)
344
+
345
+ # Check for any weight files
346
+ has_weights = any(f.endswith(('.bin', '.safetensors')) for f in model_files)
347
+
348
+ return all(f in model_files for f in required_files) and has_weights
349
+
350
+ async def download_model_files():
351
+ """Downloads the model files using git clone from Hugging Face repository"""
352
+ try:
353
+ print(f"[INFO] Processing model from {Settings.MODEL_REPO}...")
354
+
355
+ # Create models directory
356
+ models_dir = os.path.join(os.getcwd(), "models")
357
+ os.makedirs(models_dir, exist_ok=True)
358
+ print(f"[INFO] Models directory: {models_dir}")
359
+
360
+ # Get the model name from the repository URL
361
+ model_name = Settings.MODEL_REPO.split('/')[-1]
362
+
363
+ # Create versioned model directory
364
+ version = get_next_model_version(models_dir, model_name)
365
+ model_base_dir = os.path.join(models_dir, model_name)
366
+ model_version_dir = os.path.join(model_base_dir, f"v{version}")
367
+
368
+ # Check if previous version exists and is valid
369
+ if version > 1:
370
+ prev_version_dir = os.path.join(model_base_dir, f"v{version-1}")
371
+ if check_existing_model(prev_version_dir):
372
+ print(f"[INFO] Using existing model from {prev_version_dir}")
373
+ model_path = prev_version_dir
374
+ state.is_model_loaded = True
375
+ else:
376
+ # Clone new version if previous is invalid or incomplete
377
+ os.makedirs(model_version_dir, exist_ok=True)
378
+ success = clone_repository(Settings.MODEL_REPO, model_version_dir)
379
+ if not success:
380
+ raise Exception("Failed to clone repository")
381
+ model_path = model_version_dir
382
+ print(f"[INFO] Successfully cloned model to {model_path}")
383
+ else:
384
+ # First time download
385
+ os.makedirs(model_version_dir, exist_ok=True)
386
+ success = clone_repository(Settings.MODEL_REPO, model_version_dir)
387
+ if not success:
388
+ raise Exception("Failed to clone repository")
389
+ model_path = model_version_dir
390
+ print(f"[INFO] Successfully cloned model to {model_path}")
391
+
392
+ # Record where the model lives so helpers like split_model_weights can find it
+ state.model_path = model_path
+
+ # Load and parse the config
393
+ config_path = os.path.join(model_path, "config.json")
394
+ if os.path.exists(config_path):
395
+ with open(config_path, 'r') as f:
396
+ state.model_config = json.load(f)
397
+ print("[INFO] Loaded model configuration")
398
+ print(f"[INFO] Model type: {state.model_config.get('model_type', 'unknown')}")
399
+ print(f"[INFO] Architecture: {state.model_config.get('architectures', ['unknown'])[0]}")
400
+ else:
401
+ print("[WARN] No config.json found in model directory")
402
+
403
+ # Scan for model files
404
+ print("[INFO] Scanning for model files...")
405
+ for root, _, files in os.walk(model_path):
406
+ for file in files:
407
+ if file.endswith(('.bin', '.json', '.safetensors')):
408
+ file_path = os.path.join(root, file)
409
+ state.model_files[file] = file_path
410
+ print(f"[INFO] Found model file: {file}")
411
+
412
+ if state.model_files:
413
+ state.is_model_loaded = True
414
+ print(f"[INFO] Model files found successfully! Total files: {len(state.model_files)}")
415
+ print(f"[INFO] Model location: {model_path}")
416
+ return True
417
+ else:
418
+ raise ValueError("No model files were found in the repository")
419
+
420
+ except Exception as e:
421
+ print(f"[ERROR] Failed to process model files: {e}")
422
+ state.is_model_loaded = False
423
+ raise
424
+
425
+ async def check_tensor_server_health(url: HttpUrl) -> bool:
426
+ """Checks if a tensor server is healthy"""
427
+ try:
428
+ async with aiohttp.ClientSession() as session:
429
+ async with session.get(f"{url}/health", timeout=Settings.TENSOR_SERVER_TIMEOUT) as response:
430
+ return response.status == 200
431
+ except Exception:
432
+ return False
433
+
434
+ # ===== API Endpoints =====
435
+ async def execute_tensor_operation(operation_id: str, server_url: HttpUrl, operation: str, data: Dict):
436
+ """Execute an operation on a tensor server and wait for results"""
437
+ try:
438
+ async with aiohttp.ClientSession() as session:
439
+ # Start the operation
440
+ async with session.post(
441
+ f"{server_url}/{operation}",
442
+ json=data,
443
+ timeout=Settings.TENSOR_SERVER_TIMEOUT
444
+ ) as response:
445
+ if response.status != 200:
446
+ error_msg = await response.text()
447
+ raise HTTPException(
448
+ status_code=response.status,
449
+ detail=f"Operation failed on server {server_url}: {error_msg}"
450
+ )
451
+
452
+ initial_response = await response.json()
453
+ if initial_response.get("status") == "completed":
454
+ # Operation completed immediately
455
+ state.operation_results[operation_id] = initial_response
456
+ return initial_response
457
+
458
+ # Operation is async, poll for results
459
+ while True:
460
+ await asyncio.sleep(1) # Poll interval
461
+ async with session.get(
462
+ f"{server_url}/operation/{initial_response['operation_id']}",
463
+ timeout=Settings.TENSOR_SERVER_TIMEOUT
464
+ ) as status_response:
465
+ if status_response.status != 200:
466
+ raise HTTPException(
467
+ status_code=status_response.status,
468
+ detail=f"Failed to get operation status from {server_url}"
469
+ )
470
+
471
+ status_data = await status_response.json()
472
+ if status_data["status"] in ["completed", "failed"]:
473
+ state.operation_results[operation_id] = status_data
474
+ if status_data["status"] == "failed":
475
+ raise HTTPException(
476
+ status_code=500,
477
+ detail=f"Operation failed on server {server_url}: {status_data.get('error')}"
478
+ )
479
+ return status_data
480
+
481
+ except asyncio.TimeoutError:
482
+ raise HTTPException(
483
+ status_code=504,
484
+ detail=f"Operation timed out on server {server_url}"
485
+ )
486
+ except Exception as e:
487
+ raise HTTPException(
488
+ status_code=500,
489
+ detail=f"Error executing operation on {server_url}: {str(e)}"
490
+ )
491
+
492
+ @app.post("/execute/{operation}")
493
+ async def execute_operation(operation: str, data: Dict):
494
+ """Execute an operation across tensor servers and collect results"""
495
+ operation_id = f"{operation}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{len(state.operation_results)}"
496
+
497
+ # Get available servers with required chunks
498
+ available_servers = [
499
+ server for server in state.tensor_servers.values()
500
+ if server.status in ["ready", "busy"]
501
+ and server.metrics.error_count < Settings.MAX_ERROR_THRESHOLD
502
+ ]
503
+
504
+ if not available_servers:
505
+ raise HTTPException(
506
+ status_code=503,
507
+ detail="No available tensor servers"
508
+ )
509
+
510
+ # Start operations on all relevant servers in parallel
511
+ tasks = []
512
+ for server in available_servers:
513
+ if operation in ["compute", "forward"]:
514
+ # For compute operations, only use servers with required chunks
515
+ required_chunks = data.get("required_chunks", [])
516
+ if not all(chunk_id in server.model_chunks for chunk_id in required_chunks):
517
+ continue
518
+
519
+ task = asyncio.create_task(
520
+ execute_tensor_operation(
521
+ f"{operation_id}_{server.url}",
522
+ server.url,
523
+ operation,
524
+ data
525
+ )
526
+ )
527
+ tasks.append(task)
528
+ state.pending_operations[f"{operation_id}_{server.url}"] = task
529
+
530
+ if not tasks:
531
+ raise HTTPException(
532
+ status_code=400,
533
+ detail="No servers available with required model chunks"
534
+ )
535
+
536
+ try:
537
+ # Wait for all operations to complete
538
+ results = await asyncio.gather(*tasks)
539
+
540
+ # Process and aggregate results
541
+ aggregated_result = {
542
+ "operation_id": operation_id,
543
+ "status": "completed",
544
+ "server_results": results,
545
+ "timestamp": datetime.now().isoformat()
546
+ }
547
+
548
+ # Clean up
549
+ for task_id in list(state.pending_operations.keys()):
550
+ if task_id.startswith(operation_id):
551
+ del state.pending_operations[task_id]
552
+
553
+ return aggregated_result
554
+
555
+ except Exception as e:
556
+ # Cancel any remaining tasks
557
+ for task in tasks:
558
+ if not task.done():
559
+ task.cancel()
560
+
561
+ # Clean up
562
+ for task_id in list(state.pending_operations.keys()):
563
+ if task_id.startswith(operation_id):
564
+ del state.pending_operations[task_id]
565
+
566
+ raise HTTPException(
567
+ status_code=500,
568
+ detail=f"Operation failed: {str(e)}"
569
+ )
570
+
571
+ @app.get("/operation/{operation_id}")
572
+ async def get_operation_status(operation_id: str):
573
+ """Get the status of an operation"""
574
+ # Check completed operations
575
+ results = {
576
+ k: v for k, v in state.operation_results.items()
577
+ if k.startswith(operation_id)
578
+ }
579
+
580
+ if results:
581
+ return {
582
+ "operation_id": operation_id,
583
+ "status": "completed",
584
+ "results": results
585
+ }
586
+
587
+ # Check pending operations
588
+ pending = {
589
+ k: "running" for k in state.pending_operations.keys()
590
+ if k.startswith(operation_id)
591
+ }
592
+
593
+ if pending:
594
+ return {
595
+ "operation_id": operation_id,
596
+ "status": "running",
597
+ "pending_servers": list(pending.keys())
598
+ }
599
+
600
+ raise HTTPException(
601
+ status_code=404,
602
+ detail=f"Operation {operation_id} not found"
603
+ )
604
+
605
+ @app.get("/")
606
+ async def root():
607
+ """Health check endpoint"""
608
+ return {
609
+ "status": "running",
610
+ "model_loaded": state.is_model_loaded,
611
+ "registered_servers": len(state.tensor_servers),
612
+ "downloaded_files": len(state.model_files),
613
+ "config_loaded": bool(state.model_config)
614
+ }
615
+
616
+ @app.get("/health")
617
+ async def health_check():
618
+ """Detailed health check"""
619
+ return {
620
+ "status": "healthy",
621
+ "model_loaded": state.is_model_loaded,
622
+ "registered_servers": len(state.tensor_servers),
623
+ "downloaded_files": list(state.model_files.keys()),
624
+ "config_loaded": bool(state.model_config),
625
+ "model_type": state.model_config.get("model_type", "unknown")
626
+ }
627
+
628
+ @app.post("/register_tensor_server")
629
+ async def register_tensor_server(server_url: HttpUrl):
630
+ """Register a new tensor server"""
631
+ if not await check_tensor_server_health(server_url):
632
+ raise HTTPException(status_code=400, detail="Tensor server is not healthy")
633
+
634
+ state.tensor_servers[str(server_url)] = TensorServer(url=server_url)
635
+ print(f"[INFO] Registered new tensor server at {server_url}")
636
+
637
+ return {
638
+ "status": "registered",
639
+ "registered_servers": len(state.tensor_servers),
640
+ "server_id": str(server_url)
641
+ }
642
+
643
+ @app.delete("/unregister_tensor_server")
644
+ async def unregister_tensor_server(server_url: HttpUrl):
645
+ """Unregister a tensor server"""
646
+ if str(server_url) in state.tensor_servers:
647
+ # Remove server assignments from chunks
648
+ for chunk in state.model_chunks.values():
649
+ if str(server_url) in chunk.server_assignments:
650
+ chunk.server_assignments.remove(str(server_url))
651
+
652
+ del state.tensor_servers[str(server_url)]
653
+ print(f"[INFO] Unregistered tensor server at {server_url}")
654
+
655
+ # Trigger redistribution of chunks
656
+ await distribute_model_chunks()
657
+ return {"status": "unregistered"}
658
+ raise HTTPException(status_code=404, detail="Server not found")
659
+
660
+ @app.get("/server/{server_url}/chunks")
661
+ async def get_server_chunks(server_url: HttpUrl):
662
+ """Get the chunks assigned to a specific server"""
663
+ if str(server_url) not in state.tensor_servers:
664
+ raise HTTPException(status_code=404, detail="Server not found")
665
+
666
+ server = state.tensor_servers[str(server_url)]
667
+ assigned_chunks = [
668
+ state.model_chunks[chunk_id]
669
+ for chunk_id in server.model_chunks
670
+ ]
671
+
672
+ return {
673
+ "server_status": server.status,
674
+ "assigned_chunks": assigned_chunks,
675
+ "metrics": server.metrics.dict()
676
+ }
677
+
678
+ @app.post("/redistribute")
679
+ async def redistribute_chunks():
680
+ """Manually trigger redistribution of model chunks"""
681
+ success = await distribute_model_chunks()
682
+ if not success:
683
+ raise HTTPException(status_code=500, detail="Failed to redistribute chunks")
684
+
685
+ return {
686
+ "status": "redistributed",
687
+ "chunk_assignments": {
688
+ chunk_id: chunk.server_assignments
689
+ for chunk_id, chunk in state.model_chunks.items()
690
+ }
691
+ }
692
+
693
+ @app.get("/chunks/{chunk_id}/status")
694
+ async def get_chunk_status(chunk_id: int):
695
+ """Get the status and assignments of a specific chunk"""
696
+ if chunk_id not in state.model_chunks:
697
+ raise HTTPException(status_code=404, detail="Chunk not found")
698
+
699
+ chunk = state.model_chunks[chunk_id]
700
+ return {
701
+ "chunk_id": chunk_id,
702
+ "status": chunk.status,
703
+ "server_assignments": chunk.server_assignments,
704
+ "metrics": chunk.metrics
705
+ }
706
+
707
+ @app.post("/initialize")
708
+ async def initialize_system():
709
+ """Download model files and prepare for distribution"""
710
+ await download_model_files()
711
+
712
+ # Verify downloaded files
713
+ files_status = {}
714
+ total_size = 0
715
+ for filename, filepath in state.model_files.items():
716
+ exists = os.path.exists(filepath)
717
+ if exists:
718
+ size = os.path.getsize(filepath)
719
+ total_size += size
720
+ files_status[filename] = {"exists": exists, "size_bytes": size}
721
+ else:
722
+ files_status[filename] = {"exists": exists, "size_bytes": 0}
723
+
724
+ return {
725
+ "status": "initialized",
726
+ "model_loaded": state.is_model_loaded,
727
+ "files_status": files_status,
728
+ "total_size_bytes": total_size,
729
+ "config_loaded": bool(state.model_config),
730
+ "model_type": state.model_config.get("model_type", "unknown"),
731
+ "architecture": state.model_config.get("architectures", ["unknown"])[0]
732
+ }
733
+
734
+ # ===== Main Execution =====
735
+ @app.on_event("startup")
736
+ async def startup_event():
737
+ """Initialize the server and start background tasks"""
738
+ print("[INFO] Initializing system...")
739
+ await initialize_system()
740
+ print("[INFO] Model initialization complete")
741
+
742
+ # Start monitoring task
743
+ asyncio.create_task(monitor_tensor_servers())
744
+ print("[INFO] Server monitoring started")
745
+
746
+ if __name__ == "__main__":
747
+ port = int(os.getenv("PORT", 8000))
748
+ print(f"[INFO] Starting controller server on port {port}")
749
+ print(f"[INFO] API Documentation available at http://localhost:{port}/docs")
750
+
751
+ uvicorn.run(
752
+ "controller_server_new:app",
753
+ host="0.0.0.0",
754
+ port=port,
755
+ reload=False
756
  )
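
For reference, below is a minimal client-side sketch of how a tensor server could be attached to the controller introduced in this commit. It only uses endpoints defined above (/health, /register_tensor_server, /redistribute); the controller address and the registering server's URL are placeholder assumptions taken from the Settings defaults, and FastAPI is assumed to expose the server_url: HttpUrl argument as a query parameter since it is not a request-body model.

# Hypothetical helper, not part of the commit: registers one tensor server with the
# controller defined above and asks it to (re)distribute model chunks.
import asyncio
import aiohttp

CONTROLLER_URL = "http://192.168.1.100:8000"         # assumption: Settings.CONTROLLER_BASE_URL default
TENSOR_SERVER_URL = "https://fred808-ilob.hf.space"  # assumption: one of the default tensor servers

async def register_and_redistribute() -> None:
    async with aiohttp.ClientSession() as session:
        # Controller health check (GET /health)
        async with session.get(f"{CONTROLLER_URL}/health") as resp:
            print("controller health:", await resp.json())

        # Register the tensor server; server_url travels as a query parameter
        async with session.post(
            f"{CONTROLLER_URL}/register_tensor_server",
            params={"server_url": TENSOR_SERVER_URL},
        ) as resp:
            print("register:", resp.status, await resp.json())

        # Trigger chunk (re)distribution across all registered servers
        async with session.post(f"{CONTROLLER_URL}/redistribute") as resp:
            print("redistribute:", resp.status, await resp.json())

if __name__ == "__main__":
    asyncio.run(register_and_redistribute())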