Update app.py
Browse files
app.py
CHANGED
|
@@ -136,26 +136,39 @@ class ControllerState:
|
|
| 136 |
self.chunks_dir: str = "" # Directory containing chunk files
|
| 137 |
self.operation_results: Dict[str, Dict] = {} # Track operation results from tensor servers
|
| 138 |
self.pending_operations: Dict[str, asyncio.Task] = {} # Track ongoing operations
|
|
|
|
|
|
|
| 139 |
|
| 140 |
state = ControllerState()
|
| 141 |
|
| 142 |
# ===== Helper Functions =====
|
| 143 |
async def split_model_weights():
|
| 144 |
-
"""Split model files into chunks
|
| 145 |
try:
|
| 146 |
import os
|
| 147 |
import math
|
| 148 |
import shutil
|
|
|
|
|
|
|
| 149 |
from pathlib import Path
|
| 150 |
|
| 151 |
-
# Find model file
|
| 152 |
try:
|
| 153 |
-
model_file = next(f for f in state.model_files.values() if f.endswith('.
|
| 154 |
-
print(f"[INFO] Found
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
except StopIteration:
|
| 156 |
try:
|
| 157 |
-
model_file = next(f for f in state.model_files.values() if f.endswith('.
|
| 158 |
-
print(f"[INFO] Found
|
| 159 |
except StopIteration:
|
| 160 |
raise Exception("No model weight files found")
|
| 161 |
|
|
@@ -397,6 +410,7 @@ async def send_chunk_to_server(server_url: str, chunk_id: int, chunk_info: Dict)
|
|
| 397 |
}
|
| 398 |
|
| 399 |
async with aiohttp.ClientSession() as session:
|
|
|
|
| 400 |
async with session.post(
|
| 401 |
f"{server_url}/load_chunk",
|
| 402 |
json=chunk_data,
|
|
@@ -404,11 +418,58 @@ async def send_chunk_to_server(server_url: str, chunk_id: int, chunk_info: Dict)
|
|
| 404 |
) as response:
|
| 405 |
if response.status != 200:
|
| 406 |
error_msg = await response.text()
|
| 407 |
-
raise Exception(f"Failed to
|
| 408 |
-
|
| 409 |
result = await response.json()
|
| 410 |
-
|
| 411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
|
| 413 |
except Exception as e:
|
| 414 |
print(f"[ERROR] Failed to send chunk {chunk_id} to {server_url}: {str(e)}")
|
|
@@ -923,16 +984,66 @@ async def redistribute_chunks():
|
|
| 923 |
|
| 924 |
@app.get("/chunks/{chunk_id}/status")
|
| 925 |
async def get_chunk_status(chunk_id: int):
|
| 926 |
-
"""Get
|
| 927 |
if chunk_id not in state.model_chunks:
|
| 928 |
raise HTTPException(status_code=404, detail="Chunk not found")
|
| 929 |
|
| 930 |
chunk = state.model_chunks[chunk_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 931 |
return {
|
| 932 |
"chunk_id": chunk_id,
|
| 933 |
"status": chunk.status,
|
| 934 |
-
"
|
| 935 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 936 |
}
|
| 937 |
|
| 938 |
@app.post("/initialize")
|
|
@@ -1042,7 +1153,7 @@ if __name__ == "__main__":
|
|
| 1042 |
print(f"[INFO] API Documentation available at http://localhost:{port}/docs")
|
| 1043 |
|
| 1044 |
uvicorn.run(
|
| 1045 |
-
"
|
| 1046 |
host="0.0.0.0",
|
| 1047 |
port=port,
|
| 1048 |
reload=False
|
|
|
|
| 136 |
self.chunks_dir: str = "" # Directory containing chunk files
|
| 137 |
self.operation_results: Dict[str, Dict] = {} # Track operation results from tensor servers
|
| 138 |
self.pending_operations: Dict[str, asyncio.Task] = {} # Track ongoing operations
|
| 139 |
+
self.chunk_assignments: Dict[int, List[Dict[str, any]]] = {} # Track which chunks are on which servers
|
| 140 |
+
self.chunk_distribution_history: List[Dict[str, any]] = [] # Track distribution history with timestamps
|
| 141 |
|
| 142 |
state = ControllerState()
|
| 143 |
|
| 144 |
# ===== Helper Functions =====
|
| 145 |
async def split_model_weights():
|
| 146 |
+
"""Split model files into chunks and convert to safetensors format"""
|
| 147 |
try:
|
| 148 |
import os
|
| 149 |
import math
|
| 150 |
import shutil
|
| 151 |
+
import torch
|
| 152 |
+
from safetensors.torch import save_file, load_file
|
| 153 |
from pathlib import Path
|
| 154 |
|
| 155 |
+
# Find model file and convert to safetensors if needed
|
| 156 |
try:
|
| 157 |
+
model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
|
| 158 |
+
print(f"[INFO] Found PyTorch file: {model_file}")
|
| 159 |
+
|
| 160 |
+
# Convert to safetensors
|
| 161 |
+
print("[INFO] Converting model to safetensors format...")
|
| 162 |
+
weights = torch.load(model_file, map_location='cpu')
|
| 163 |
+
safetensors_path = os.path.join(state.model_path, "model.safetensors")
|
| 164 |
+
save_file(weights, safetensors_path)
|
| 165 |
+
model_file = safetensors_path
|
| 166 |
+
print(f"[INFO] Converted model to safetensors format: {model_file}")
|
| 167 |
+
|
| 168 |
except StopIteration:
|
| 169 |
try:
|
| 170 |
+
model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
|
| 171 |
+
print(f"[INFO] Found existing safetensors file: {model_file}")
|
| 172 |
except StopIteration:
|
| 173 |
raise Exception("No model weight files found")
|
| 174 |
|
|
|
|
| 410 |
}
|
| 411 |
|
| 412 |
async with aiohttp.ClientSession() as session:
|
| 413 |
+
# Step 1: Send chunk configuration
|
| 414 |
async with session.post(
|
| 415 |
f"{server_url}/load_chunk",
|
| 416 |
json=chunk_data,
|
|
|
|
| 418 |
) as response:
|
| 419 |
if response.status != 200:
|
| 420 |
error_msg = await response.text()
|
| 421 |
+
raise Exception(f"Failed to register chunk: {error_msg}")
|
| 422 |
+
|
| 423 |
result = await response.json()
|
| 424 |
+
if not result.get("ready_for_data", False):
|
| 425 |
+
raise Exception("Server not ready for chunk data")
|
| 426 |
+
|
| 427 |
+
# Step 2: Upload chunk data
|
| 428 |
+
with open(chunk_path, 'rb') as f:
|
| 429 |
+
chunk_file = f.read()
|
| 430 |
+
|
| 431 |
+
form = aiohttp.FormData()
|
| 432 |
+
form.add_field('file',
|
| 433 |
+
chunk_file,
|
| 434 |
+
filename=os.path.basename(chunk_path),
|
| 435 |
+
content_type='application/octet-stream')
|
| 436 |
+
|
| 437 |
+
async with session.post(
|
| 438 |
+
f"{server_url}/upload_chunk_data/{chunk_id}",
|
| 439 |
+
data=form,
|
| 440 |
+
timeout=Settings.TENSOR_SERVER_TIMEOUT
|
| 441 |
+
) as upload_response:
|
| 442 |
+
if upload_response.status != 200:
|
| 443 |
+
error_msg = await upload_response.text()
|
| 444 |
+
raise Exception(f"Failed to upload chunk data: {error_msg}")
|
| 445 |
+
|
| 446 |
+
upload_result = await upload_response.json()
|
| 447 |
+
|
| 448 |
+
# Track the assignment
|
| 449 |
+
if chunk_id not in state.chunk_assignments:
|
| 450 |
+
state.chunk_assignments[chunk_id] = []
|
| 451 |
+
|
| 452 |
+
assignment = {
|
| 453 |
+
"server_url": server_url,
|
| 454 |
+
"timestamp": datetime.now().isoformat(),
|
| 455 |
+
"status": "loaded",
|
| 456 |
+
"size_bytes": upload_result.get('size_bytes', 0)
|
| 457 |
+
}
|
| 458 |
+
state.chunk_assignments[chunk_id].append(assignment)
|
| 459 |
+
|
| 460 |
+
# Add to history
|
| 461 |
+
state.chunk_distribution_history.append({
|
| 462 |
+
"chunk_id": chunk_id,
|
| 463 |
+
"server_url": server_url,
|
| 464 |
+
"timestamp": datetime.now().isoformat(),
|
| 465 |
+
"action": "upload",
|
| 466 |
+
"status": "success",
|
| 467 |
+
"size_bytes": upload_result.get('size_bytes', 0)
|
| 468 |
+
})
|
| 469 |
+
|
| 470 |
+
print(f"[INFO] Successfully uploaded chunk {chunk_id} to {server_url} ({upload_result.get('size_bytes', 0)} bytes)")
|
| 471 |
+
print(f"[INFO] Current assignments for chunk {chunk_id}: {len(state.chunk_assignments[chunk_id])} servers")
|
| 472 |
+
return True
|
| 473 |
|
| 474 |
except Exception as e:
|
| 475 |
print(f"[ERROR] Failed to send chunk {chunk_id} to {server_url}: {str(e)}")
|
|
|
|
| 984 |
|
| 985 |
@app.get("/chunks/{chunk_id}/status")
|
| 986 |
async def get_chunk_status(chunk_id: int):
|
| 987 |
+
"""Get detailed status and assignments of a specific chunk"""
|
| 988 |
if chunk_id not in state.model_chunks:
|
| 989 |
raise HTTPException(status_code=404, detail="Chunk not found")
|
| 990 |
|
| 991 |
chunk = state.model_chunks[chunk_id]
|
| 992 |
+
assignments = state.chunk_assignments.get(chunk_id, [])
|
| 993 |
+
|
| 994 |
+
# Get current server status for each assignment
|
| 995 |
+
current_status = []
|
| 996 |
+
for assignment in assignments:
|
| 997 |
+
server_url = assignment["server_url"]
|
| 998 |
+
if server_url in state.tensor_servers:
|
| 999 |
+
server = state.tensor_servers[server_url]
|
| 1000 |
+
current_status.append({
|
| 1001 |
+
"server_url": server_url,
|
| 1002 |
+
"server_status": server.status,
|
| 1003 |
+
"last_heartbeat": server.last_heartbeat.isoformat(),
|
| 1004 |
+
"metrics": server.metrics.dict(),
|
| 1005 |
+
"assignment_time": assignment["timestamp"]
|
| 1006 |
+
})
|
| 1007 |
+
|
| 1008 |
return {
|
| 1009 |
"chunk_id": chunk_id,
|
| 1010 |
"status": chunk.status,
|
| 1011 |
+
"size_bytes": chunk.size_bytes,
|
| 1012 |
+
"current_assignments": current_status,
|
| 1013 |
+
"assignment_history": [
|
| 1014 |
+
h for h in state.chunk_distribution_history
|
| 1015 |
+
if h["chunk_id"] == chunk_id
|
| 1016 |
+
],
|
| 1017 |
+
"metrics": chunk.metrics,
|
| 1018 |
+
"config": chunk.config
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
@app.get("/distribution/status")
|
| 1022 |
+
async def get_distribution_status():
|
| 1023 |
+
"""Get overall distribution status of all chunks"""
|
| 1024 |
+
distribution_summary = {}
|
| 1025 |
+
|
| 1026 |
+
for chunk_id, chunk in state.model_chunks.items():
|
| 1027 |
+
assignments = state.chunk_assignments.get(chunk_id, [])
|
| 1028 |
+
active_servers = [
|
| 1029 |
+
a["server_url"] for a in assignments
|
| 1030 |
+
if a["server_url"] in state.tensor_servers and
|
| 1031 |
+
state.tensor_servers[a["server_url"]].status in ["ready", "busy"]
|
| 1032 |
+
]
|
| 1033 |
+
|
| 1034 |
+
distribution_summary[chunk_id] = {
|
| 1035 |
+
"total_assignments": len(assignments),
|
| 1036 |
+
"active_servers": len(active_servers),
|
| 1037 |
+
"server_urls": active_servers,
|
| 1038 |
+
"size_bytes": chunk.size_bytes,
|
| 1039 |
+
"status": chunk.status
|
| 1040 |
+
}
|
| 1041 |
+
|
| 1042 |
+
return {
|
| 1043 |
+
"total_chunks": len(state.model_chunks),
|
| 1044 |
+
"total_servers": len(state.tensor_servers),
|
| 1045 |
+
"chunks": distribution_summary,
|
| 1046 |
+
"history": state.chunk_distribution_history[-10:] # Last 10 events
|
| 1047 |
}
|
| 1048 |
|
| 1049 |
@app.post("/initialize")
|
|
|
|
| 1153 |
print(f"[INFO] API Documentation available at http://localhost:{port}/docs")
|
| 1154 |
|
| 1155 |
uvicorn.run(
|
| 1156 |
+
"controller_server_new:app",
|
| 1157 |
host="0.0.0.0",
|
| 1158 |
port=port,
|
| 1159 |
reload=False
|