Fred808 committed
Commit f3a8698 · verified · 1 Parent(s): 336d73a

Update app.py

Files changed (1):
  app.py: +125 -14
app.py CHANGED
@@ -136,26 +136,39 @@ class ControllerState:
         self.chunks_dir: str = ""  # Directory containing chunk files
         self.operation_results: Dict[str, Dict] = {}  # Track operation results from tensor servers
         self.pending_operations: Dict[str, asyncio.Task] = {}  # Track ongoing operations
+        self.chunk_assignments: Dict[int, List[Dict[str, any]]] = {}  # Track which chunks are on which servers
+        self.chunk_distribution_history: List[Dict[str, any]] = []  # Track distribution history with timestamps
 
 state = ControllerState()
 
 # ===== Helper Functions =====
 async def split_model_weights():
-    """Split model files into chunks based on available servers without loading into memory"""
+    """Split model files into chunks and convert to safetensors format"""
     try:
         import os
         import math
         import shutil
+        import torch
+        from safetensors.torch import save_file, load_file
         from pathlib import Path
 
-        # Find model file (safetensors or pytorch)
+        # Find model file and convert to safetensors if needed
         try:
-            model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
-            print(f"[INFO] Found safetensors file: {model_file}")
+            model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
+            print(f"[INFO] Found PyTorch file: {model_file}")
+
+            # Convert to safetensors
+            print("[INFO] Converting model to safetensors format...")
+            weights = torch.load(model_file, map_location='cpu')
+            safetensors_path = os.path.join(state.model_path, "model.safetensors")
+            save_file(weights, safetensors_path)
+            model_file = safetensors_path
+            print(f"[INFO] Converted model to safetensors format: {model_file}")
+
         except StopIteration:
             try:
-                model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
-                print(f"[INFO] Found PyTorch file: {model_file}")
+                model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
+                print(f"[INFO] Found existing safetensors file: {model_file}")
             except StopIteration:
                 raise Exception("No model weight files found")
 
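Two notes on the hunk above. The new conversion path loads the whole .bin checkpoint into CPU memory via torch.load(), which quietly drops the old docstring's "without loading into memory" property, so the controller needs RAM on the order of the checkpoint size. Also, Dict[str, any] in the new state fields refers to the builtin any() function; typing.Any is the intended annotation (harmless at runtime, but wrong for type checkers). A standalone sketch of the same conversion with two common failure cases handled; the function name and the "state_dict" unwrapping are illustrative assumptions, not part of this commit:

import torch
from safetensors.torch import save_file

def convert_bin_to_safetensors(bin_path: str, out_path: str) -> str:
    # Load on CPU so the conversion does not require a GPU
    weights = torch.load(bin_path, map_location="cpu")
    # Some checkpoints wrap the weights, e.g. {"state_dict": {...}}
    if isinstance(weights, dict) and "state_dict" in weights:
        weights = weights["state_dict"]
    # save_file() rejects non-contiguous tensors
    weights = {k: v.contiguous() for k, v in weights.items()}
    save_file(weights, out_path)
    return out_path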
@@ -397,6 +410,7 @@ async def send_chunk_to_server(server_url: str, chunk_id: int, chunk_info: Dict)
         }
 
         async with aiohttp.ClientSession() as session:
+            # Step 1: Send chunk configuration
             async with session.post(
                 f"{server_url}/load_chunk",
                 json=chunk_data,
@@ -404,11 +418,58 @@ async def send_chunk_to_server(server_url: str, chunk_id: int, chunk_info: Dict)
             ) as response:
                 if response.status != 200:
                     error_msg = await response.text()
-                    raise Exception(f"Failed to load chunk: {error_msg}")
-
+                    raise Exception(f"Failed to register chunk: {error_msg}")
+
                 result = await response.json()
-                print(f"[INFO] Successfully loaded chunk {chunk_id} to {server_url}")
-                return True
+                if not result.get("ready_for_data", False):
+                    raise Exception("Server not ready for chunk data")
+
+            # Step 2: Upload chunk data
+            with open(chunk_path, 'rb') as f:
+                chunk_file = f.read()
+
+            form = aiohttp.FormData()
+            form.add_field('file',
+                           chunk_file,
+                           filename=os.path.basename(chunk_path),
+                           content_type='application/octet-stream')
+
+            async with session.post(
+                f"{server_url}/upload_chunk_data/{chunk_id}",
+                data=form,
+                timeout=Settings.TENSOR_SERVER_TIMEOUT
+            ) as upload_response:
+                if upload_response.status != 200:
+                    error_msg = await upload_response.text()
+                    raise Exception(f"Failed to upload chunk data: {error_msg}")
+
+                upload_result = await upload_response.json()
+
+            # Track the assignment
+            if chunk_id not in state.chunk_assignments:
+                state.chunk_assignments[chunk_id] = []
+
+            assignment = {
+                "server_url": server_url,
+                "timestamp": datetime.now().isoformat(),
+                "status": "loaded",
+                "size_bytes": upload_result.get('size_bytes', 0)
+            }
+            state.chunk_assignments[chunk_id].append(assignment)
+
+            # Add to history
+            state.chunk_distribution_history.append({
+                "chunk_id": chunk_id,
+                "server_url": server_url,
+                "timestamp": datetime.now().isoformat(),
+                "action": "upload",
+                "status": "success",
+                "size_bytes": upload_result.get('size_bytes', 0)
+            })
+
+            print(f"[INFO] Successfully uploaded chunk {chunk_id} to {server_url} ({upload_result.get('size_bytes', 0)} bytes)")
+            print(f"[INFO] Current assignments for chunk {chunk_id}: {len(state.chunk_assignments[chunk_id])} servers")
+            return True
 
     except Exception as e:
         print(f"[ERROR] Failed to send chunk {chunk_id} to {server_url}: {str(e)}")
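The upload is now a two-step protocol: register the chunk via POST /load_chunk, then send the bytes to POST /upload_chunk_data/{chunk_id} as multipart form data. The tensor-server side is not part of this commit; a hypothetical FastAPI sketch of the contract the controller now expects (a ready_for_data flag in the registration response, size_bytes in the upload response) could look like:

from fastapi import FastAPI, File, UploadFile

app = FastAPI()
chunk_configs: dict = {}  # chunk_id -> registration payload

@app.post("/load_chunk")
async def load_chunk(chunk_data: dict):
    # Accept the chunk configuration and signal readiness for the data
    chunk_configs[chunk_data.get("chunk_id")] = chunk_data
    return {"ready_for_data": True}

@app.post("/upload_chunk_data/{chunk_id}")
async def upload_chunk_data(chunk_id: int, file: UploadFile = File(...)):
    # Multipart handling requires python-multipart to be installed
    data = await file.read()
    with open(f"chunk_{chunk_id}.safetensors", "wb") as f:
        f.write(data)
    return {"size_bytes": len(data)}

Note also that the controller reads each chunk fully into memory (chunk_file = f.read()) before building the form; aiohttp.FormData also accepts an open file object, which would stream the upload instead.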
@@ -923,16 +984,66 @@ async def redistribute_chunks():
 
 @app.get("/chunks/{chunk_id}/status")
 async def get_chunk_status(chunk_id: int):
-    """Get the status and assignments of a specific chunk"""
+    """Get detailed status and assignments of a specific chunk"""
     if chunk_id not in state.model_chunks:
         raise HTTPException(status_code=404, detail="Chunk not found")
 
     chunk = state.model_chunks[chunk_id]
+    assignments = state.chunk_assignments.get(chunk_id, [])
+
+    # Get current server status for each assignment
+    current_status = []
+    for assignment in assignments:
+        server_url = assignment["server_url"]
+        if server_url in state.tensor_servers:
+            server = state.tensor_servers[server_url]
+            current_status.append({
+                "server_url": server_url,
+                "server_status": server.status,
+                "last_heartbeat": server.last_heartbeat.isoformat(),
+                "metrics": server.metrics.dict(),
+                "assignment_time": assignment["timestamp"]
+            })
+
     return {
         "chunk_id": chunk_id,
         "status": chunk.status,
-        "server_assignments": chunk.server_assignments,
-        "metrics": chunk.metrics
+        "size_bytes": chunk.size_bytes,
+        "current_assignments": current_status,
+        "assignment_history": [
+            h for h in state.chunk_distribution_history
+            if h["chunk_id"] == chunk_id
+        ],
+        "metrics": chunk.metrics,
+        "config": chunk.config
+    }
+
+@app.get("/distribution/status")
+async def get_distribution_status():
+    """Get overall distribution status of all chunks"""
+    distribution_summary = {}
+
+    for chunk_id, chunk in state.model_chunks.items():
+        assignments = state.chunk_assignments.get(chunk_id, [])
+        active_servers = [
+            a["server_url"] for a in assignments
+            if a["server_url"] in state.tensor_servers and
+            state.tensor_servers[a["server_url"]].status in ["ready", "busy"]
+        ]
+
+        distribution_summary[chunk_id] = {
+            "total_assignments": len(assignments),
+            "active_servers": len(active_servers),
+            "server_urls": active_servers,
+            "size_bytes": chunk.size_bytes,
+            "status": chunk.status
+        }
+
+    return {
+        "total_chunks": len(state.model_chunks),
+        "total_servers": len(state.tensor_servers),
+        "chunks": distribution_summary,
+        "history": state.chunk_distribution_history[-10:]  # Last 10 events
     }
 
 @app.post("/initialize")
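Once the controller is running, the new read endpoints can be exercised with a short client; the base URL is a placeholder:

import asyncio
import aiohttp

async def check_distribution(base: str = "http://localhost:8000") -> None:
    async with aiohttp.ClientSession() as session:
        # One summary entry per chunk, plus the last 10 history events
        async with session.get(f"{base}/distribution/status") as resp:
            summary = await resp.json()
        print(f"{summary['total_chunks']} chunks on {summary['total_servers']} servers")

        # Per-chunk detail (JSON object keys arrive as strings)
        for chunk_id in summary["chunks"]:
            async with session.get(f"{base}/chunks/{chunk_id}/status") as resp:
                detail = await resp.json()
            print(chunk_id, detail["status"], len(detail["current_assignments"]))

asyncio.run(check_distribution())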
@@ -1042,7 +1153,7 @@ if __name__ == "__main__":
     print(f"[INFO] API Documentation available at http://localhost:{port}/docs")
 
     uvicorn.run(
-        "app:app",
+        "controller_server_new:app",
         host="0.0.0.0",
         port=port,
         reload=False
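One caveat on the import-string change: uvicorn.run("controller_server_new:app", ...) makes uvicorn import a module named controller_server_new, but this commit edits app.py, so startup would fail with an import error unless a controller_server_new.py also exists in the repo. Since reload=False, the app object can be passed directly, which removes the module-name coupling; a sketch:

# The string form is only needed for reload/workers, neither used here
uvicorn.run(app, host="0.0.0.0", port=port)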
 