Fred808 committed on
Commit
73df3b2
·
verified ·
1 Parent(s): 5b3f4f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -34
app.py CHANGED
@@ -139,46 +139,94 @@ state = ControllerState()
139
 
140
  # ===== Helper Functions =====
141
  async def split_model_weights():
142
- """Split model weights into chunks based on available servers"""
143
  try:
144
- import torch
145
  import math
 
 
146
 
147
- # Install required packages if not present
148
- try:
149
- import safetensors
150
- except ImportError:
151
- print("[INFO] Installing required packages...")
152
- import subprocess
153
- subprocess.check_call(["pip", "install", "safetensors", "packaging"])
154
-
155
- # Load the full model weights
156
- import torch
157
- from safetensors.torch import load_file as load_safetensors
158
-
159
- # Try safetensors first with chunked loading, then fallback to pytorch
160
  try:
161
  model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
162
- print(f"[INFO] Loading weights from safetensors file: {model_file}")
163
- try:
164
- # Try direct loading first
165
- weights = load_safetensors(model_file)
166
- except Exception as e:
167
- if "header too large" in str(e):
168
- print("[INFO] Large header detected, attempting chunked loading...")
169
- from safetensors import safe_open
170
- weights = {}
171
- with safe_open(model_file, framework="pt") as f:
172
- for key in f.keys():
173
- weights[key] = f.get_tensor(key)
174
- print("[INFO] Successfully loaded weights using chunked loading")
175
- else:
176
- raise e
177
  except StopIteration:
178
- # No safetensors file found, try pytorch
179
- model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
180
- print(f"[INFO] Loading weights from PyTorch file: {model_file}")
181
- weights = torch.load(model_file, map_location='cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  # Calculate total model size and chunks
184
  total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
 
139
 
140
# ===== Helper Functions =====
async def split_model_weights():
    """Split the model weight file into byte-range chunks, one per tensor server.

    Locates the model file registered in ``state.model_files`` (preferring
    ``.safetensors``, falling back to ``.bin``), then streams it into
    ``num_servers`` equally-sized byte chunks under a sibling ``chunks/``
    directory, recording a ``ModelChunk`` entry (offset, size, ordering
    metadata) in ``state.model_chunks`` for each one.

    NOTE(review): chunks are raw byte ranges of the original file — they are
    only meaningful once reassembled in order; presumably the tensor servers
    concatenate them before loading. Verify against the consumer.

    Returns:
        bool: True on success, False on any failure (the error is printed,
        never raised, so callers must check the return value).
    """
    try:
        import os
        import math

        # Find the model file: prefer safetensors, fall back to a PyTorch .bin.
        try:
            model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
            print(f"[INFO] Found safetensors file: {model_file}")
        except StopIteration:
            try:
                model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
                print(f"[INFO] Found PyTorch file: {model_file}")
            except StopIteration:
                raise Exception("No model weight files found")

        # Get file size and calculate chunks.
        file_size = os.path.getsize(model_file)
        num_servers = len(state.tensor_servers) or len(Settings.TENSOR_SERVER_URLS)
        num_chunks = num_servers  # One chunk per server initially

        chunk_size = math.ceil(file_size / num_chunks)
        print(f"[INFO] Model file size: {file_size / (1024*1024*1024):.2f} GB")
        print(f"[INFO] Creating {num_chunks} chunks of {chunk_size / (1024*1024):.2f} MB each")

        # Create chunks directory if it doesn't exist.
        chunks_dir = os.path.join(os.path.dirname(model_file), "chunks")
        os.makedirs(chunks_dir, exist_ok=True)

        # Stream the file into chunks with a bounded buffer so peak memory is
        # COPY_BUFFER_SIZE regardless of chunk size. (Previously each whole
        # chunk — potentially several GB — was read into memory in one call,
        # defeating the "without loading into memory" goal.)
        COPY_BUFFER_SIZE = 64 * 1024 * 1024  # 64 MB per read
        chunk_sizes = []  # Track actual chunk sizes for the final verification
        with open(model_file, 'rb') as src:
            for chunk_id in range(num_chunks):
                # Calculate chunk boundaries; the last chunk may be short.
                start_pos = chunk_id * chunk_size
                remaining = file_size - start_pos
                current_chunk_size = min(chunk_size, remaining)

                if current_chunk_size <= 0:
                    break

                chunk_path = os.path.join(chunks_dir, f"chunk_{chunk_id}.bin")

                # Copy exactly current_chunk_size bytes, buffer by buffer.
                # Reads are sequential, so no seek is needed between chunks.
                to_copy = current_chunk_size
                with open(chunk_path, 'wb') as chunk_file:
                    while to_copy > 0:
                        buf = src.read(min(COPY_BUFFER_SIZE, to_copy))
                        if not buf:
                            # Unexpected EOF; the size check below will warn.
                            break
                        chunk_file.write(buf)
                        to_copy -= len(buf)

                chunk_sizes.append(current_chunk_size - to_copy)

                # Create chunk metadata so servers can reassemble in order.
                state.model_chunks[chunk_id] = ModelChunk(
                    chunk_id=chunk_id,
                    files=[f"chunk_{chunk_id}.bin"],
                    config={
                        "start_offset": start_pos,
                        "size_bytes": current_chunk_size,
                        "is_last_chunk": chunk_id == num_chunks - 1,
                        "total_chunks": num_chunks,
                        "original_file": os.path.basename(model_file)
                    },
                    size_bytes=current_chunk_size,
                    status="ready"
                )

                print(f"[INFO] Created chunk {chunk_id}: {current_chunk_size / (1024*1024):.2f} MB")

        # Verify that the chunks account for every byte of the original file.
        total_size_actual = sum(chunk_sizes)
        if total_size_actual != file_size:
            print(f"[WARN] Total chunk size ({total_size_actual}) differs from original file size ({file_size})")

        print(f"\n[INFO] Distribution Summary:")
        print(f"- Original file: {os.path.basename(model_file)}")
        print(f"- Total size: {file_size / (1024*1024*1024):.2f} GB")
        print(f"- Number of chunks: {len(state.model_chunks)}")
        print(f"- Chunks directory: {chunks_dir}")
        print(f"- Chunk size: {chunk_size / (1024*1024):.2f} MB")

        return True

    except Exception as e:
        print(f"[ERROR] Failed to split model weights: {str(e)}")
        return False
230
 
231
  # Calculate total model size and chunks
232
  total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())