Fred808 committed on
Commit
5819251
·
verified ·
1 Parent(s): f3a8698

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -96
app.py CHANGED
@@ -136,39 +136,26 @@ class ControllerState:
136
  self.chunks_dir: str = "" # Directory containing chunk files
137
  self.operation_results: Dict[str, Dict] = {} # Track operation results from tensor servers
138
  self.pending_operations: Dict[str, asyncio.Task] = {} # Track ongoing operations
139
- self.chunk_assignments: Dict[int, List[Dict[str, any]]] = {} # Track which chunks are on which servers
140
- self.chunk_distribution_history: List[Dict[str, any]] = [] # Track distribution history with timestamps
141
 
142
  state = ControllerState()
143
 
144
  # ===== Helper Functions =====
145
  async def split_model_weights():
146
- """Split model files into chunks and convert to safetensors format"""
147
  try:
148
  import os
149
  import math
150
  import shutil
151
- import torch
152
- from safetensors.torch import save_file, load_file
153
  from pathlib import Path
154
 
155
- # Find model file and convert to safetensors if needed
156
  try:
157
- model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
158
- print(f"[INFO] Found PyTorch file: {model_file}")
159
-
160
- # Convert to safetensors
161
- print("[INFO] Converting model to safetensors format...")
162
- weights = torch.load(model_file, map_location='cpu')
163
- safetensors_path = os.path.join(state.model_path, "model.safetensors")
164
- save_file(weights, safetensors_path)
165
- model_file = safetensors_path
166
- print(f"[INFO] Converted model to safetensors format: {model_file}")
167
-
168
  except StopIteration:
169
  try:
170
- model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
171
- print(f"[INFO] Found existing safetensors file: {model_file}")
172
  except StopIteration:
173
  raise Exception("No model weight files found")
174
 
@@ -444,31 +431,7 @@ async def send_chunk_to_server(server_url: str, chunk_id: int, chunk_info: Dict)
444
  raise Exception(f"Failed to upload chunk data: {error_msg}")
445
 
446
  upload_result = await upload_response.json()
447
-
448
- # Track the assignment
449
- if chunk_id not in state.chunk_assignments:
450
- state.chunk_assignments[chunk_id] = []
451
-
452
- assignment = {
453
- "server_url": server_url,
454
- "timestamp": datetime.now().isoformat(),
455
- "status": "loaded",
456
- "size_bytes": upload_result.get('size_bytes', 0)
457
- }
458
- state.chunk_assignments[chunk_id].append(assignment)
459
-
460
- # Add to history
461
- state.chunk_distribution_history.append({
462
- "chunk_id": chunk_id,
463
- "server_url": server_url,
464
- "timestamp": datetime.now().isoformat(),
465
- "action": "upload",
466
- "status": "success",
467
- "size_bytes": upload_result.get('size_bytes', 0)
468
- })
469
-
470
  print(f"[INFO] Successfully uploaded chunk {chunk_id} to {server_url} ({upload_result.get('size_bytes', 0)} bytes)")
471
- print(f"[INFO] Current assignments for chunk {chunk_id}: {len(state.chunk_assignments[chunk_id])} servers")
472
  return True
473
 
474
  except Exception as e:
@@ -984,66 +947,16 @@ async def redistribute_chunks():
984
 
985
  @app.get("/chunks/{chunk_id}/status")
986
  async def get_chunk_status(chunk_id: int):
987
- """Get detailed status and assignments of a specific chunk"""
988
  if chunk_id not in state.model_chunks:
989
  raise HTTPException(status_code=404, detail="Chunk not found")
990
 
991
  chunk = state.model_chunks[chunk_id]
992
- assignments = state.chunk_assignments.get(chunk_id, [])
993
-
994
- # Get current server status for each assignment
995
- current_status = []
996
- for assignment in assignments:
997
- server_url = assignment["server_url"]
998
- if server_url in state.tensor_servers:
999
- server = state.tensor_servers[server_url]
1000
- current_status.append({
1001
- "server_url": server_url,
1002
- "server_status": server.status,
1003
- "last_heartbeat": server.last_heartbeat.isoformat(),
1004
- "metrics": server.metrics.dict(),
1005
- "assignment_time": assignment["timestamp"]
1006
- })
1007
-
1008
  return {
1009
  "chunk_id": chunk_id,
1010
  "status": chunk.status,
1011
- "size_bytes": chunk.size_bytes,
1012
- "current_assignments": current_status,
1013
- "assignment_history": [
1014
- h for h in state.chunk_distribution_history
1015
- if h["chunk_id"] == chunk_id
1016
- ],
1017
- "metrics": chunk.metrics,
1018
- "config": chunk.config
1019
- }
1020
-
1021
- @app.get("/distribution/status")
1022
- async def get_distribution_status():
1023
- """Get overall distribution status of all chunks"""
1024
- distribution_summary = {}
1025
-
1026
- for chunk_id, chunk in state.model_chunks.items():
1027
- assignments = state.chunk_assignments.get(chunk_id, [])
1028
- active_servers = [
1029
- a["server_url"] for a in assignments
1030
- if a["server_url"] in state.tensor_servers and
1031
- state.tensor_servers[a["server_url"]].status in ["ready", "busy"]
1032
- ]
1033
-
1034
- distribution_summary[chunk_id] = {
1035
- "total_assignments": len(assignments),
1036
- "active_servers": len(active_servers),
1037
- "server_urls": active_servers,
1038
- "size_bytes": chunk.size_bytes,
1039
- "status": chunk.status
1040
- }
1041
-
1042
- return {
1043
- "total_chunks": len(state.model_chunks),
1044
- "total_servers": len(state.tensor_servers),
1045
- "chunks": distribution_summary,
1046
- "history": state.chunk_distribution_history[-10:] # Last 10 events
1047
  }
1048
 
1049
  @app.post("/initialize")
 
136
  self.chunks_dir: str = "" # Directory containing chunk files
137
  self.operation_results: Dict[str, Dict] = {} # Track operation results from tensor servers
138
  self.pending_operations: Dict[str, asyncio.Task] = {} # Track ongoing operations
 
 
139
 
140
  state = ControllerState()
141
 
142
  # ===== Helper Functions =====
143
  async def split_model_weights():
144
+ """Split model files into chunks based on available servers without loading into memory"""
145
  try:
146
  import os
147
  import math
148
  import shutil
 
 
149
  from pathlib import Path
150
 
151
+ # Find model file (safetensors or pytorch)
152
  try:
153
+ model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
154
+ print(f"[INFO] Found safetensors file: {model_file}")
 
 
 
 
 
 
 
 
 
155
  except StopIteration:
156
  try:
157
+ model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
158
+ print(f"[INFO] Found PyTorch file: {model_file}")
159
  except StopIteration:
160
  raise Exception("No model weight files found")
161
 
 
431
  raise Exception(f"Failed to upload chunk data: {error_msg}")
432
 
433
  upload_result = await upload_response.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  print(f"[INFO] Successfully uploaded chunk {chunk_id} to {server_url} ({upload_result.get('size_bytes', 0)} bytes)")
 
435
  return True
436
 
437
  except Exception as e:
 
947
 
948
  @app.get("/chunks/{chunk_id}/status")
949
  async def get_chunk_status(chunk_id: int):
950
+ """Get the status and assignments of a specific chunk"""
951
  if chunk_id not in state.model_chunks:
952
  raise HTTPException(status_code=404, detail="Chunk not found")
953
 
954
  chunk = state.model_chunks[chunk_id]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
955
  return {
956
  "chunk_id": chunk_id,
957
  "status": chunk.status,
958
+ "server_assignments": chunk.server_assignments,
959
+ "metrics": chunk.metrics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
960
  }
961
 
962
  @app.post("/initialize")