Update app.py
Browse files
app.py
CHANGED
|
@@ -136,39 +136,26 @@ class ControllerState:
|
|
| 136 |
self.chunks_dir: str = "" # Directory containing chunk files
|
| 137 |
self.operation_results: Dict[str, Dict] = {} # Track operation results from tensor servers
|
| 138 |
self.pending_operations: Dict[str, asyncio.Task] = {} # Track ongoing operations
|
| 139 |
-
self.chunk_assignments: Dict[int, List[Dict[str, any]]] = {} # Track which chunks are on which servers
|
| 140 |
-
self.chunk_distribution_history: List[Dict[str, any]] = [] # Track distribution history with timestamps
|
| 141 |
|
| 142 |
state = ControllerState()
|
| 143 |
|
| 144 |
# ===== Helper Functions =====
|
| 145 |
async def split_model_weights():
|
| 146 |
-
"""Split model files into chunks
|
| 147 |
try:
|
| 148 |
import os
|
| 149 |
import math
|
| 150 |
import shutil
|
| 151 |
-
import torch
|
| 152 |
-
from safetensors.torch import save_file, load_file
|
| 153 |
from pathlib import Path
|
| 154 |
|
| 155 |
-
# Find model file
|
| 156 |
try:
|
| 157 |
-
model_file = next(f for f in state.model_files.values() if f.endswith('.
|
| 158 |
-
print(f"[INFO] Found
|
| 159 |
-
|
| 160 |
-
# Convert to safetensors
|
| 161 |
-
print("[INFO] Converting model to safetensors format...")
|
| 162 |
-
weights = torch.load(model_file, map_location='cpu')
|
| 163 |
-
safetensors_path = os.path.join(state.model_path, "model.safetensors")
|
| 164 |
-
save_file(weights, safetensors_path)
|
| 165 |
-
model_file = safetensors_path
|
| 166 |
-
print(f"[INFO] Converted model to safetensors format: {model_file}")
|
| 167 |
-
|
| 168 |
except StopIteration:
|
| 169 |
try:
|
| 170 |
-
model_file = next(f for f in state.model_files.values() if f.endswith('.
|
| 171 |
-
print(f"[INFO] Found
|
| 172 |
except StopIteration:
|
| 173 |
raise Exception("No model weight files found")
|
| 174 |
|
|
@@ -444,31 +431,7 @@ async def send_chunk_to_server(server_url: str, chunk_id: int, chunk_info: Dict)
|
|
| 444 |
raise Exception(f"Failed to upload chunk data: {error_msg}")
|
| 445 |
|
| 446 |
upload_result = await upload_response.json()
|
| 447 |
-
|
| 448 |
-
# Track the assignment
|
| 449 |
-
if chunk_id not in state.chunk_assignments:
|
| 450 |
-
state.chunk_assignments[chunk_id] = []
|
| 451 |
-
|
| 452 |
-
assignment = {
|
| 453 |
-
"server_url": server_url,
|
| 454 |
-
"timestamp": datetime.now().isoformat(),
|
| 455 |
-
"status": "loaded",
|
| 456 |
-
"size_bytes": upload_result.get('size_bytes', 0)
|
| 457 |
-
}
|
| 458 |
-
state.chunk_assignments[chunk_id].append(assignment)
|
| 459 |
-
|
| 460 |
-
# Add to history
|
| 461 |
-
state.chunk_distribution_history.append({
|
| 462 |
-
"chunk_id": chunk_id,
|
| 463 |
-
"server_url": server_url,
|
| 464 |
-
"timestamp": datetime.now().isoformat(),
|
| 465 |
-
"action": "upload",
|
| 466 |
-
"status": "success",
|
| 467 |
-
"size_bytes": upload_result.get('size_bytes', 0)
|
| 468 |
-
})
|
| 469 |
-
|
| 470 |
print(f"[INFO] Successfully uploaded chunk {chunk_id} to {server_url} ({upload_result.get('size_bytes', 0)} bytes)")
|
| 471 |
-
print(f"[INFO] Current assignments for chunk {chunk_id}: {len(state.chunk_assignments[chunk_id])} servers")
|
| 472 |
return True
|
| 473 |
|
| 474 |
except Exception as e:
|
|
@@ -984,66 +947,16 @@ async def redistribute_chunks():
|
|
| 984 |
|
| 985 |
@app.get("/chunks/{chunk_id}/status")
|
| 986 |
async def get_chunk_status(chunk_id: int):
|
| 987 |
-
"""Get
|
| 988 |
if chunk_id not in state.model_chunks:
|
| 989 |
raise HTTPException(status_code=404, detail="Chunk not found")
|
| 990 |
|
| 991 |
chunk = state.model_chunks[chunk_id]
|
| 992 |
-
assignments = state.chunk_assignments.get(chunk_id, [])
|
| 993 |
-
|
| 994 |
-
# Get current server status for each assignment
|
| 995 |
-
current_status = []
|
| 996 |
-
for assignment in assignments:
|
| 997 |
-
server_url = assignment["server_url"]
|
| 998 |
-
if server_url in state.tensor_servers:
|
| 999 |
-
server = state.tensor_servers[server_url]
|
| 1000 |
-
current_status.append({
|
| 1001 |
-
"server_url": server_url,
|
| 1002 |
-
"server_status": server.status,
|
| 1003 |
-
"last_heartbeat": server.last_heartbeat.isoformat(),
|
| 1004 |
-
"metrics": server.metrics.dict(),
|
| 1005 |
-
"assignment_time": assignment["timestamp"]
|
| 1006 |
-
})
|
| 1007 |
-
|
| 1008 |
return {
|
| 1009 |
"chunk_id": chunk_id,
|
| 1010 |
"status": chunk.status,
|
| 1011 |
-
"
|
| 1012 |
-
"
|
| 1013 |
-
"assignment_history": [
|
| 1014 |
-
h for h in state.chunk_distribution_history
|
| 1015 |
-
if h["chunk_id"] == chunk_id
|
| 1016 |
-
],
|
| 1017 |
-
"metrics": chunk.metrics,
|
| 1018 |
-
"config": chunk.config
|
| 1019 |
-
}
|
| 1020 |
-
|
| 1021 |
-
@app.get("/distribution/status")
|
| 1022 |
-
async def get_distribution_status():
|
| 1023 |
-
"""Get overall distribution status of all chunks"""
|
| 1024 |
-
distribution_summary = {}
|
| 1025 |
-
|
| 1026 |
-
for chunk_id, chunk in state.model_chunks.items():
|
| 1027 |
-
assignments = state.chunk_assignments.get(chunk_id, [])
|
| 1028 |
-
active_servers = [
|
| 1029 |
-
a["server_url"] for a in assignments
|
| 1030 |
-
if a["server_url"] in state.tensor_servers and
|
| 1031 |
-
state.tensor_servers[a["server_url"]].status in ["ready", "busy"]
|
| 1032 |
-
]
|
| 1033 |
-
|
| 1034 |
-
distribution_summary[chunk_id] = {
|
| 1035 |
-
"total_assignments": len(assignments),
|
| 1036 |
-
"active_servers": len(active_servers),
|
| 1037 |
-
"server_urls": active_servers,
|
| 1038 |
-
"size_bytes": chunk.size_bytes,
|
| 1039 |
-
"status": chunk.status
|
| 1040 |
-
}
|
| 1041 |
-
|
| 1042 |
-
return {
|
| 1043 |
-
"total_chunks": len(state.model_chunks),
|
| 1044 |
-
"total_servers": len(state.tensor_servers),
|
| 1045 |
-
"chunks": distribution_summary,
|
| 1046 |
-
"history": state.chunk_distribution_history[-10:] # Last 10 events
|
| 1047 |
}
|
| 1048 |
|
| 1049 |
@app.post("/initialize")
|
|
|
|
| 136 |
self.chunks_dir: str = "" # Directory containing chunk files
|
| 137 |
self.operation_results: Dict[str, Dict] = {} # Track operation results from tensor servers
|
| 138 |
self.pending_operations: Dict[str, asyncio.Task] = {} # Track ongoing operations
|
|
|
|
|
|
|
| 139 |
|
| 140 |
state = ControllerState()
|
| 141 |
|
| 142 |
# ===== Helper Functions =====
|
| 143 |
async def split_model_weights():
|
| 144 |
+
"""Split model files into chunks based on available servers without loading into memory"""
|
| 145 |
try:
|
| 146 |
import os
|
| 147 |
import math
|
| 148 |
import shutil
|
|
|
|
|
|
|
| 149 |
from pathlib import Path
|
| 150 |
|
| 151 |
+
# Find model file (safetensors or pytorch)
|
| 152 |
try:
|
| 153 |
+
model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
|
| 154 |
+
print(f"[INFO] Found safetensors file: {model_file}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
except StopIteration:
|
| 156 |
try:
|
| 157 |
+
model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
|
| 158 |
+
print(f"[INFO] Found PyTorch file: {model_file}")
|
| 159 |
except StopIteration:
|
| 160 |
raise Exception("No model weight files found")
|
| 161 |
|
|
|
|
| 431 |
raise Exception(f"Failed to upload chunk data: {error_msg}")
|
| 432 |
|
| 433 |
upload_result = await upload_response.json()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
print(f"[INFO] Successfully uploaded chunk {chunk_id} to {server_url} ({upload_result.get('size_bytes', 0)} bytes)")
|
|
|
|
| 435 |
return True
|
| 436 |
|
| 437 |
except Exception as e:
|
|
|
|
| 947 |
|
| 948 |
@app.get("/chunks/{chunk_id}/status")
|
| 949 |
async def get_chunk_status(chunk_id: int):
|
| 950 |
+
"""Get the status and assignments of a specific chunk"""
|
| 951 |
if chunk_id not in state.model_chunks:
|
| 952 |
raise HTTPException(status_code=404, detail="Chunk not found")
|
| 953 |
|
| 954 |
chunk = state.model_chunks[chunk_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 955 |
return {
|
| 956 |
"chunk_id": chunk_id,
|
| 957 |
"status": chunk.status,
|
| 958 |
+
"server_assignments": chunk.server_assignments,
|
| 959 |
+
"metrics": chunk.metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 960 |
}
|
| 961 |
|
| 962 |
@app.post("/initialize")
|