Fred808 committed (verified)
Commit 19f193d · Parent(s): d3e88f3

Update app.py

Files changed (1): app.py (+168 -39)
app.py CHANGED
@@ -19,12 +19,10 @@ class Settings:
 
     # List of tensor server URLs - should be actual IP addresses or hostnames
     TENSOR_SERVER_URLS = os.getenv("TENSOR_SERVER_URLS", "").split(",") or [
-        "https://fred808-ilob.hf.space",  # Example IP for tensor server 1
-        "https://fred808-tserv.hf.space",  # Example IP for tensor server 2
-        "https://fred808-tserve2.hf.space"  # Example IP for tensor server 3
+        "https://fred808-ilob.hf.space",
+        "https://fred808-tserv.hf.space",
+        "https://fred808-tserve2.hf.space"
     ]
-
-    # Aggregator settings - should be actual IP or hostname
     AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
 
     # Model settings
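
Reviewer note (not part of the commit): the `or [...]` fallback above still never fires, because `"".split(",")` returns `[""]`, which is truthy, so the default list is dead code whenever the env var is unset or empty. A minimal sketch of a parse that actually falls back to the defaults (the `parse_server_urls` helper is ours, for illustration only):

    import os

    _DEFAULT_TENSOR_SERVERS = [
        "https://fred808-ilob.hf.space",
        "https://fred808-tserv.hf.space",
        "https://fred808-tserve2.hf.space",
    ]

    def parse_server_urls(env_value):
        # Drop empty entries so "" and "a,,b" behave sensibly
        urls = [u.strip() for u in (env_value or "").split(",") if u.strip()]
        return urls or _DEFAULT_TENSOR_SERVERS

    TENSOR_SERVER_URLS = parse_server_urls(os.getenv("TENSOR_SERVER_URLS"))
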
@@ -142,31 +140,58 @@ async def split_model_weights():
     """Split model weights into chunks based on available servers"""
     try:
         import torch
+        import math
 
         # Load the full model weights
         model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors') or f.endswith('.bin'))
         weights = torch.load(model_file, map_location='cpu')
 
-        # Calculate chunks based on number of servers
-        total_params = sum(p.numel() for p in weights.values())
+        # Calculate total model size and chunks
+        total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
         num_servers = len(state.tensor_servers) or len(Settings.TENSOR_SERVER_URLS)
-        params_per_chunk = Settings.get_optimal_chunk_size(total_params, num_servers)
 
-        print(f"[INFO] Total parameters: {total_params:,}")
+        # Determine optimal number of chunks based on server count
+        # If 2 servers -> 2 chunks (500MB each for 1GB)
+        # If 3 servers -> 3 chunks (333MB each for 1GB)
+        num_chunks = num_servers
+        bytes_per_chunk = math.ceil(total_size_bytes / num_chunks)
+
+        print(f"[INFO] Total model size: {total_size_bytes / (1024*1024*1024):.2f} GB")
         print(f"[INFO] Available servers: {num_servers}")
-        print(f"[INFO] Parameters per chunk: {params_per_chunk:,}")
+        print(f"[INFO] Creating {num_chunks} chunks")
+        print(f"[INFO] Target chunk size: {bytes_per_chunk / (1024*1024):.2f} MB")
 
         current_chunk = []
-        current_size = 0
+        current_chunk_size = 0
         chunk_id = 0
+        chunk_sizes = []  # Track actual chunk sizes for verification
+
+        # Sort weights by size for better distribution
+        sorted_weights = sorted(
+            weights.items(),
+            key=lambda x: x[1].nelement() * x[1].element_size(),
+            reverse=True
+        )
 
         for key, tensor in weights.items():
             tensor_size = tensor.numel()
 
-            if current_size + tensor_size > params_per_chunk and current_chunk:
+            # Calculate tensor size in bytes
+            tensor_size = tensor.nelement() * tensor.element_size()
+
+            # If adding this tensor would exceed chunk size and we have tensors in current chunk
+            if (current_chunk_size + tensor_size > bytes_per_chunk and current_chunk) or \
+               (chunk_id == num_chunks - 1):  # Last chunk gets remaining tensors
+
                 # Save current chunk
                 chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors")
-                torch.save({k: weights[k] for k in current_chunk}, chunk_path)
+                chunk_weights = {k: weights[k] for k in current_chunk}
+                torch.save(chunk_weights, chunk_path)
+
+                # Calculate chunk stats
+                chunk_total_size = sum(weights[k].nelement() * weights[k].element_size()
+                                       for k in current_chunk)
+                chunk_sizes.append(chunk_total_size)
 
                 # Create chunk metadata
                 state.model_chunks[chunk_id] = ModelChunk(
@@ -174,41 +199,115 @@ async def split_model_weights():
                     files=[f"chunk_{chunk_id}.safetensors"],
                     config={
                         "weight_keys": current_chunk,
-                        "input_size": weights[current_chunk[0]].size(1),
-                        "output_size": weights[current_chunk[-1]].size(0)
+                        "size_bytes": chunk_total_size,
+                        "num_parameters": sum(weights[k].nelement() for k in current_chunk),
+                        "input_size": weights[current_chunk[0]].size(1) if len(current_chunk) > 0 else 0,
+                        "output_size": weights[current_chunk[-1]].size(0) if len(current_chunk) > 0 else 0
                     }
                 )
 
+                print(f"[INFO] Created chunk {chunk_id}: {chunk_total_size / (1024*1024):.2f} MB, "
+                      f"{len(current_chunk)} tensors")
+
                 # Reset for next chunk
                 current_chunk = []
-                current_size = 0
+                current_chunk_size = 0
                 chunk_id += 1
+
+                # If we've created all chunks except last one, put remaining tensors in last chunk
+                if chunk_id == num_chunks - 1:
+                    remaining_tensors = [k for k, _ in sorted_weights if k not in sum([c.config["weight_keys"]
+                                         for c in state.model_chunks.values()], [])]
+                    current_chunk.extend(remaining_tensors)
+                    continue
 
+            # Add tensor to current chunk
             current_chunk.append(key)
-            current_size += tensor_size
+            current_chunk_size += tensor_size
 
         # Save last chunk if not empty
         if current_chunk:
             chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors")
-            torch.save({k: weights[k] for k in current_chunk}, chunk_path)
+            chunk_weights = {k: weights[k] for k in current_chunk}
+            torch.save(chunk_weights, chunk_path)
+
+            # Calculate final chunk stats
+            chunk_total_size = sum(weights[k].nelement() * weights[k].element_size()
+                                   for k in current_chunk)
+            chunk_sizes.append(chunk_total_size)
 
             state.model_chunks[chunk_id] = ModelChunk(
                 chunk_id=chunk_id,
                 files=[f"chunk_{chunk_id}.safetensors"],
                 config={
                     "weight_keys": current_chunk,
+                    "size_bytes": chunk_total_size,
+                    "num_parameters": sum(weights[k].nelement() for k in current_chunk),
                     "input_size": weights[current_chunk[0]].size(1),
                     "output_size": weights[current_chunk[-1]].size(0)
                 }
            )
+
+            print(f"[INFO] Created final chunk {chunk_id}: {chunk_total_size / (1024*1024):.2f} MB, "
+                  f"{len(current_chunk)} tensors")
+
+        # Verify distribution
+        total_size_actual = sum(chunk_sizes)
+        size_std_dev = torch.tensor(chunk_sizes).std().item() / (1024*1024)  # MB
+        size_mean = torch.tensor(chunk_sizes).mean().item() / (1024*1024)  # MB
+
+        print(f"\n[INFO] Distribution Summary:")
+        print(f"- Total model size: {total_size_actual / (1024*1024*1024):.2f} GB")
+        print(f"- Number of chunks: {len(state.model_chunks)}")
+        print(f"- Average chunk size: {size_mean:.2f} MB")
+        print(f"- Chunk size std dev: {size_std_dev:.2f} MB")
+        print(f"- Size variation: {(size_std_dev/size_mean*100):.1f}%")
+
+        # Verify all weights were distributed
+        all_distributed = set(sum([c.config["weight_keys"] for c in state.model_chunks.values()], []))
+        if len(all_distributed) != len(weights):
+            missing = set(weights.keys()) - all_distributed
+            print(f"[WARN] Some weights were not distributed: {missing}")
 
-        print(f"[INFO] Split model into {len(state.model_chunks)} chunks")
         return True
 
     except Exception as e:
         print(f"[ERROR] Failed to split model weights: {str(e)}")
         return False
 
+async def send_chunk_to_server(server_url: str, chunk_id: int, chunk_info: Dict):
+    """Send a model chunk to a tensor server"""
+    try:
+        print(f"[INFO] Sending chunk {chunk_id} to server {server_url}")
+        chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors")
+
+        if not os.path.exists(chunk_path):
+            raise Exception(f"Chunk file not found: {chunk_path}")
+
+        chunk_data = {
+            'chunk_id': chunk_id,
+            'files': [f"chunk_{chunk_id}.safetensors"],
+            'config': chunk_info['config']
+        }
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"{server_url}/load_chunk",
+                json=chunk_data,
+                timeout=Settings.TENSOR_SERVER_TIMEOUT
+            ) as response:
+                if response.status != 200:
+                    error_msg = await response.text()
+                    raise Exception(f"Failed to load chunk: {error_msg}")
+
+                result = await response.json()
+                print(f"[INFO] Successfully loaded chunk {chunk_id} to {server_url}")
+                return True
+
+    except Exception as e:
+        print(f"[ERROR] Failed to send chunk {chunk_id} to {server_url}: {str(e)}")
+        return False
+
 async def distribute_model_chunks():
     """Distribute model chunks across available tensor servers"""
     try:
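
Reviewer note (not part of the commit): both before and after this change, chunks are written with `torch.save()` to paths ending in `.safetensors`, and the source checkpoint is read with `torch.load()` even when it is a real safetensors file. Those are different on-disk formats: `torch.load()` cannot parse safetensors files, and `torch.save()` produces a pickle regardless of the extension. A sketch using the `safetensors` package, assuming the state dict holds plain tensors:

    import torch
    from safetensors.torch import load_file, save_file

    def load_state_dict(path: str):
        # safetensors files need safetensors' own loader; .bin checkpoints are pickles
        if path.endswith(".safetensors"):
            return load_file(path, device="cpu")
        return torch.load(path, map_location="cpu")

    def save_chunk(chunk_weights: dict, chunk_path: str):
        # Writes a genuine safetensors file, so the .safetensors extension is honest
        save_file(chunk_weights, chunk_path)
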
@@ -789,28 +888,58 @@ async def startup_event():
         await initialize_system()
         print("[INFO] Model initialization complete")
 
-        # If we have pre-configured tensor servers, try to connect to them
-        if Settings.TENSOR_SERVER_URLS:
-            print(f"[INFO] Attempting to connect to {len(Settings.TENSOR_SERVER_URLS)} pre-configured tensor servers...")
-            for url in Settings.TENSOR_SERVER_URLS:
-                try:
-                    if await check_tensor_server_health(url):
-                        state.tensor_servers[str(url)] = TensorServer(url=url)
-                        print(f"[INFO] Successfully registered pre-configured server at {url}")
-                except Exception as e:
-                    print(f"[WARN] Failed to connect to pre-configured server {url}: {str(e)}")
-
-        # If we have both model and servers, start distribution
-        if state.is_model_loaded and state.tensor_servers:
-            print("[INFO] Starting initial model distribution...")
+        # Try to connect to pre-configured tensor servers
+        connected_servers = []
+        print(f"[INFO] Attempting to connect to tensor servers...")
+        for url in Settings.TENSOR_SERVER_URLS:
+            try:
+                print(f"[INFO] Testing connection to {url}...")
+                if await check_tensor_server_health(url):
+                    server = TensorServer(url=url)
+                    state.tensor_servers[str(url)] = server
+                    connected_servers.append(server)
+                    print(f"[INFO] Successfully connected to tensor server at {url}")
+            except Exception as e:
+                print(f"[WARN] Failed to connect to tensor server {url}: {str(e)}")
+
+        if connected_servers:
+            print(f"[INFO] Connected to {len(connected_servers)} tensor servers")
+
+            # Split model into chunks
+            print("[INFO] Splitting model into chunks...")
             if await split_model_weights():
-                print(f"[INFO] Split model into {len(state.model_chunks)} chunks")
-                if await distribute_model_chunks():
-                    print("[INFO] Successfully completed initial distribution")
-                else:
-                    print("[WARN] Initial distribution failed")
+                print(f"[INFO] Successfully split model into {len(state.model_chunks)} chunks")
+
+                # Actively distribute chunks to servers
+                print("[INFO] Starting chunk distribution...")
+                distribution_tasks = []
+
+                for chunk_id, chunk in state.model_chunks.items():
+                    # Send each chunk to at least 2 servers if available
+                    target_servers = connected_servers[:2]
+                    for server in target_servers:
+                        print(f"[INFO] Preparing to send chunk {chunk_id} to {server.url}")
+                        task = asyncio.create_task(
+                            send_chunk_to_server(str(server.url), chunk_id, chunk)
+                        )
+                        distribution_tasks.append(task)
+
+                        # Update assignments
+                        if str(server.url) not in chunk.server_assignments:
+                            chunk.server_assignments.append(str(server.url))
+                        if chunk_id not in server.model_chunks:
+                            server.model_chunks.append(chunk_id)
+
+                if distribution_tasks:
+                    print(f"[INFO] Waiting for {len(distribution_tasks)} distribution tasks to complete...")
+                    results = await asyncio.gather(*distribution_tasks, return_exceptions=True)
+                    success_count = sum(1 for r in results if r is True)
+                    print(f"[INFO] Successfully distributed {success_count} chunks out of {len(distribution_tasks)} attempts")
             else:
-                print("[WARN] Failed to split model weights")
+                print("[ERROR] Failed to split model weights")
+        else:
+            print("[WARN] No tensor servers available for distribution")
+
     except Exception as e:
         print(f"[ERROR] Startup error: {str(e)}")
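
Reviewer note (not part of the commit): despite the "at least 2 servers" comment, `target_servers = connected_servers[:2]` sends every chunk to the same first two servers, so a third connected server never receives anything and no sharding across the pool happens. Note also that `send_chunk_to_server` reads `chunk_info['config']` while this call site passes a `ModelChunk` object; unless `ModelChunk` supports subscripting, one side needs to change (e.g. pass `chunk.config` as a dict). A sketch of round-robin placement with a replication factor, meant to slot in where the `[:2]` loop sits (`REPLICATION_FACTOR` is our name, not in the commit):

    import asyncio

    REPLICATION_FACTOR = 2  # replicas per chunk; the commit hard-codes [:2]

    distribution_tasks = []
    for i, (chunk_id, chunk) in enumerate(state.model_chunks.items()):
        # Rotate the starting server so shards spread over the whole pool
        n = min(REPLICATION_FACTOR, len(connected_servers))
        replicas = [connected_servers[(i + r) % len(connected_servers)] for r in range(n)]
        for server in replicas:
            distribution_tasks.append(asyncio.create_task(
                send_chunk_to_server(str(server.url), chunk_id, chunk)
            ))
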
 
 
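Reviewer note (not part of the commit): three wrinkles in the new splitting loop. First, once `chunk_id` reaches `num_chunks - 1`, the `(chunk_id == num_chunks - 1)` arm of the `if` fires on every remaining iteration, and the keys pulled in via `remaining_tensors` are extended into `current_chunk` while the outer loop keeps appending those same keys, so tensors can be written into more than one chunk and more than `num_chunks` files can result. Second, `sorted_weights` is computed but the loop iterates `weights.items()`, so the size ordering never takes effect. Third, `chunk_sizes` holds Python ints and `.std()` raises on integer tensors, so the summary needs a float cast. A compact greedy alternative that avoids all three, offered as a sketch only (`partition_weights` is our name; `weights` and `num_chunks` come from the function above):

    import torch

    def partition_weights(weights: dict, num_chunks: int):
        """Greedy balanced partition: drop the next-largest tensor into the
        currently lightest chunk, so every key lands in exactly one chunk."""
        chunks = [[] for _ in range(num_chunks)]
        loads = [0] * num_chunks
        by_size = sorted(weights.items(),
                         key=lambda kv: kv[1].nelement() * kv[1].element_size(),
                         reverse=True)
        for key, tensor in by_size:
            lightest = loads.index(min(loads))
            chunks[lightest].append(key)
            loads[lightest] += tensor.nelement() * tensor.element_size()
        return chunks, loads

    chunks, loads = partition_weights(weights, num_chunks)
    sizes = torch.tensor(loads, dtype=torch.float64)  # float cast so .std() is legal
    print(f"[INFO] Chunk size spread: {sizes.std().item() / 2**20:.2f} MB")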
 
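Reviewer note (not part of the commit): `send_chunk_to_server` POSTs only JSON metadata (`chunk_id`, file names, config); the chunk's weight file never leaves disk, so `/load_chunk` can only succeed if the tensor server shares the aggregator's filesystem, which the remote `*.hf.space` defaults do not. A sketch that ships the bytes too, using aiohttp multipart (the `metadata`/`file` field names are ours; the server would need to be adapted to accept them):

    import json
    import aiohttp

    async def upload_chunk(server_url: str, chunk_id: int, chunk_path: str, config: dict) -> bool:
        form = aiohttp.FormData()
        form.add_field("metadata",
                       json.dumps({"chunk_id": chunk_id, "config": config}),
                       content_type="application/json")
        # Stream the chunk file as part of the request instead of just naming it
        form.add_field("file", open(chunk_path, "rb"),
                       filename=f"chunk_{chunk_id}.safetensors",
                       content_type="application/octet-stream")
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{server_url}/load_chunk", data=form) as resp:
                return resp.status == 200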
 