Update app.py
Browse files
app.py
CHANGED
|
@@ -139,46 +139,94 @@ state = ControllerState()
|
|
| 139 |
|
| 140 |
# ===== Helper Functions =====
|
| 141 |
async def split_model_weights():
|
| 142 |
-
"""Split model
|
| 143 |
try:
|
| 144 |
-
import
|
| 145 |
import math
|
|
|
|
|
|
|
| 146 |
|
| 147 |
-
#
|
| 148 |
-
try:
|
| 149 |
-
import safetensors
|
| 150 |
-
except ImportError:
|
| 151 |
-
print("[INFO] Installing required packages...")
|
| 152 |
-
import subprocess
|
| 153 |
-
subprocess.check_call(["pip", "install", "safetensors", "packaging"])
|
| 154 |
-
|
| 155 |
-
# Load the full model weights
|
| 156 |
-
import torch
|
| 157 |
-
from safetensors.torch import load_file as load_safetensors
|
| 158 |
-
|
| 159 |
-
# Try safetensors first with chunked loading, then fallback to pytorch
|
| 160 |
try:
|
| 161 |
model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
|
| 162 |
-
print(f"[INFO]
|
| 163 |
-
try:
|
| 164 |
-
# Try direct loading first
|
| 165 |
-
weights = load_safetensors(model_file)
|
| 166 |
-
except Exception as e:
|
| 167 |
-
if "header too large" in str(e):
|
| 168 |
-
print("[INFO] Large header detected, attempting chunked loading...")
|
| 169 |
-
from safetensors import safe_open
|
| 170 |
-
weights = {}
|
| 171 |
-
with safe_open(model_file, framework="pt") as f:
|
| 172 |
-
for key in f.keys():
|
| 173 |
-
weights[key] = f.get_tensor(key)
|
| 174 |
-
print("[INFO] Successfully loaded weights using chunked loading")
|
| 175 |
-
else:
|
| 176 |
-
raise e
|
| 177 |
except StopIteration:
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
# Calculate total model size and chunks
|
| 184 |
total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
|
|
|
|
| 139 |
|
| 140 |
# ===== Helper Functions =====
async def split_model_weights():
    """Split the model weight file into chunks, one per tensor server, without
    loading the whole model (or any whole chunk) into memory.

    Locates the weight file in ``state.model_files`` (preferring
    ``.safetensors``, falling back to ``.bin``), byte-splits it into a
    ``chunks/`` directory next to the original file, and records a
    ``ModelChunk`` (offset, size, ordering metadata) in
    ``state.model_chunks`` for each piece.

    Returns:
        bool: True on success, False on any failure (the error is printed,
        never raised to the caller).
    """
    try:
        import os
        import math

        # Find model file (safetensors or pytorch)
        try:
            model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
            print(f"[INFO] Found safetensors file: {model_file}")
        except StopIteration:
            try:
                model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
                print(f"[INFO] Found PyTorch file: {model_file}")
            except StopIteration:
                # Caught by the outer handler below and surfaced as a False return.
                raise FileNotFoundError("No model weight files found") from None

        # Get file size and calculate chunks — one chunk per server initially.
        file_size = os.path.getsize(model_file)
        num_servers = len(state.tensor_servers) or len(Settings.TENSOR_SERVER_URLS)
        num_chunks = num_servers

        chunk_size = math.ceil(file_size / num_chunks)
        print(f"[INFO] Model file size: {file_size / (1024*1024*1024):.2f} GB")
        print(f"[INFO] Creating {num_chunks} chunks of {chunk_size / (1024*1024):.2f} MB each")

        # Create chunks directory if it doesn't exist
        chunks_dir = os.path.join(os.path.dirname(model_file), "chunks")
        os.makedirs(chunks_dir, exist_ok=True)

        # Stream each chunk in bounded pieces so peak memory stays at
        # buffer_size instead of a whole (possibly multi-GB) chunk.
        buffer_size = 64 * 1024 * 1024  # 64 MiB copy buffer

        chunk_sizes = []  # Bytes actually written per chunk, for verification.
        with open(model_file, 'rb') as f:
            for chunk_id in range(num_chunks):
                chunk_path = os.path.join(chunks_dir, f"chunk_{chunk_id}.bin")

                # Calculate chunk boundaries
                start_pos = chunk_id * chunk_size
                remaining = file_size - start_pos
                current_chunk_size = min(chunk_size, remaining)

                if current_chunk_size <= 0:
                    break

                # Copy the byte range [start_pos, start_pos + current_chunk_size)
                # into the chunk file, buffer_size bytes at a time.
                f.seek(start_pos)
                to_copy = current_chunk_size
                with open(chunk_path, 'wb') as chunk_file:
                    while to_copy > 0:
                        data = f.read(min(buffer_size, to_copy))
                        if not data:
                            # Unexpected EOF; the size check below will report it.
                            break
                        chunk_file.write(data)
                        to_copy -= len(data)

                chunk_sizes.append(current_chunk_size - to_copy)

                # Create chunk metadata
                state.model_chunks[chunk_id] = ModelChunk(
                    chunk_id=chunk_id,
                    files=[f"chunk_{chunk_id}.bin"],
                    config={
                        "start_offset": start_pos,
                        "size_bytes": current_chunk_size,
                        "is_last_chunk": chunk_id == num_chunks - 1,
                        "total_chunks": num_chunks,
                        "original_file": os.path.basename(model_file)
                    },
                    size_bytes=current_chunk_size,
                    status="ready"
                )

                print(f"[INFO] Created chunk {chunk_id}: {current_chunk_size / (1024*1024):.2f} MB")

        # Verify distribution: every byte of the original must be accounted for.
        total_size_actual = sum(chunk_sizes)
        if total_size_actual != file_size:
            print(f"[WARN] Total chunk size ({total_size_actual}) differs from original file size ({file_size})")

        print("\n[INFO] Distribution Summary:")
        print(f"- Original file: {os.path.basename(model_file)}")
        print(f"- Total size: {file_size / (1024*1024*1024):.2f} GB")
        print(f"- Number of chunks: {len(state.model_chunks)}")
        print(f"- Chunks directory: {chunks_dir}")
        print(f"- Chunk size: {chunk_size / (1024*1024):.2f} MB")

        return True

    except Exception as e:
        print(f"[ERROR] Failed to split model weights: {str(e)}")
        return False
|
| 230 |
|
| 231 |
# Calculate total model size and chunks
|
| 232 |
total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
|