Fred808 commited on
Commit
501d5b0
·
verified ·
1 Parent(s): 19f193d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -55
app.py CHANGED
@@ -18,10 +18,12 @@ class Settings:
18
  CONTROLLER_BASE_URL = os.getenv("CONTROLLER_BASE_URL", "http://192.168.1.100:8000")
19
 
20
  # List of tensor server URLs - should be actual IP addresses or hostnames
21
- TENSOR_SERVER_URLS = os.getenv("TENSOR_SERVER_URLS", "").split(",") or [
 
 
22
  "https://fred808-ilob.hf.space",
23
- "https://fred808-tserv.hf.space",
24
- "https://fred808-tserve2.hf.space"
25
  ]
26
  AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
27
 
@@ -881,71 +883,49 @@ async def initialize_system():
881
  # ===== Main Execution =====
882
  @app.on_event("startup")
883
  async def startup_event():
884
- """Initialize the server and start background tasks"""
885
  print("[INFO] Initializing system...")
886
  try:
887
  # Initialize system and download model
888
  await initialize_system()
889
  print("[INFO] Model initialization complete")
890
 
891
- # Try to connect to pre-configured tensor servers
892
- connected_servers = []
893
- print(f"[INFO] Attempting to connect to tensor servers...")
894
- for url in Settings.TENSOR_SERVER_URLS:
895
- try:
896
- print(f"[INFO] Testing connection to {url}...")
897
- if await check_tensor_server_health(url):
898
- server = TensorServer(url=url)
899
- state.tensor_servers[str(url)] = server
900
- connected_servers.append(server)
901
- print(f"[INFO] Successfully connected to tensor server at {url}")
902
- except Exception as e:
903
- print(f"[WARN] Failed to connect to tensor server {url}: {str(e)}")
904
-
905
- if connected_servers:
906
- print(f"[INFO] Connected to {len(connected_servers)} tensor servers")
907
 
908
- # Split model into chunks
909
- print("[INFO] Splitting model into chunks...")
910
- if await split_model_weights():
911
- print(f"[INFO] Successfully split model into {len(state.model_chunks)} chunks")
912
-
913
- # Actively distribute chunks to servers
914
- print("[INFO] Starting chunk distribution...")
915
- distribution_tasks = []
916
-
917
- for chunk_id, chunk in state.model_chunks.items():
918
- # Send each chunk to at least 2 servers if available
919
- target_servers = connected_servers[:2]
920
- for server in target_servers:
921
- print(f"[INFO] Preparing to send chunk {chunk_id} to {server.url}")
922
- task = asyncio.create_task(
923
- send_chunk_to_server(str(server.url), chunk_id, chunk)
924
- )
925
- distribution_tasks.append(task)
926
-
927
- # Update assignments
928
- if str(server.url) not in chunk.server_assignments:
929
- chunk.server_assignments.append(str(server.url))
930
- if chunk_id not in server.model_chunks:
931
- server.model_chunks.append(chunk_id)
932
 
933
- if distribution_tasks:
934
- print(f"[INFO] Waiting for {len(distribution_tasks)} distribution tasks to complete...")
935
- results = await asyncio.gather(*distribution_tasks, return_exceptions=True)
936
- success_count = sum(1 for r in results if r is True)
937
- print(f"[INFO] Successfully distributed {success_count} chunks out of {len(distribution_tasks)} attempts")
938
- else:
939
- print("[ERROR] Failed to split model weights")
 
 
 
 
 
 
 
 
940
  else:
941
- print("[WARN] No tensor servers available for distribution")
942
 
943
  except Exception as e:
944
  print(f"[ERROR] Startup error: {str(e)}")
945
 
946
- # Start monitoring task
947
- asyncio.create_task(monitor_tensor_servers())
948
- print("[INFO] Server monitoring started")
949
 
950
  if __name__ == "__main__":
951
  port = int(os.getenv("PORT", 8000))
 
18
  CONTROLLER_BASE_URL = os.getenv("CONTROLLER_BASE_URL", "http://192.168.1.100:8000")
19
 
20
  # List of tensor server URLs - should be actual IP addresses or hostnames
21
+ TENSOR_SERVER_URLS = [
22
+ url for url in os.getenv("TENSOR_SERVER_URLS", "").split(",") if url
23
+ ] or [
24
  "https://fred808-ilob.hf.space",
25
+ "https://fred808-tserv.hf.space",
26
+ "https://fred808-tserve2.hf.space",
27
  ]
28
  AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
29
 
 
883
  # ===== Main Execution =====
884
  @app.on_event("startup")
885
  async def startup_event():
886
+ """Initialize the server and start distribution"""
887
  print("[INFO] Initializing system...")
888
  try:
889
  # Initialize system and download model
890
  await initialize_system()
891
  print("[INFO] Model initialization complete")
892
 
893
+ # Split model into chunks
894
+ if await split_model_weights():
895
+ print(f"[INFO] Successfully split model into {len(state.model_chunks)} chunks")
 
 
 
 
 
 
 
 
 
 
 
 
 
896
 
897
+ # Distribute chunks to tensor servers
898
+ print("[INFO] Starting chunk distribution...")
899
+ distribution_tasks = []
900
+
901
+ # Round-robin distribution to tensor servers
902
+ for chunk_id, chunk in state.model_chunks.items():
903
+ # Determine target servers (distribute each chunk to 2 servers for redundancy)
904
+ server_indices = [i % len(Settings.TENSOR_SERVER_URLS) for i in range(chunk_id * 2, chunk_id * 2 + 2)]
905
+ target_servers = [Settings.TENSOR_SERVER_URLS[i] for i in server_indices]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
906
 
907
+ for server_url in target_servers:
908
+ print(f"[INFO] Sending chunk {chunk_id} to {server_url}")
909
+ task = asyncio.create_task(
910
+ send_chunk_to_server(server_url, chunk_id, chunk)
911
+ )
912
+ distribution_tasks.append(task)
913
+
914
+ # Track assignments for future reference
915
+ chunk.server_assignments.append(server_url)
916
+
917
+ if distribution_tasks:
918
+ print(f"[INFO] Distributing {len(distribution_tasks)} chunks...")
919
+ results = await asyncio.gather(*distribution_tasks, return_exceptions=True)
920
+ success_count = sum(1 for r in results if r is True)
921
+ print(f"[INFO] Successfully distributed {success_count} chunks out of {len(distribution_tasks)} attempts")
922
  else:
923
+ print("[ERROR] Failed to split model weights")
924
 
925
  except Exception as e:
926
  print(f"[ERROR] Startup error: {str(e)}")
927
 
928
+ print("[INFO] Startup complete")
 
 
929
 
930
  if __name__ == "__main__":
931
  port = int(os.getenv("PORT", 8000))