Update app.py
app.py
CHANGED
@@ -19,12 +19,10 @@ class Settings:
 
     # List of tensor server URLs - should be actual IP addresses or hostnames
     TENSOR_SERVER_URLS = os.getenv("TENSOR_SERVER_URLS", "").split(",") or [
-        "https://fred808-ilob.hf.space",
-        "https://fred808-tserv.hf.space",
-        "https://fred808-tserve2.hf.space"
+        "https://fred808-ilob.hf.space",
+        "https://fred808-tserv.hf.space",
+        "https://fred808-tserve2.hf.space"
     ]
-
-    # Aggregator settings - should be actual IP or hostname
     AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
 
     # Model settings
@@ -142,31 +140,58 @@ async def split_model_weights():
     """Split model weights into chunks based on available servers"""
     try:
         import torch
+        import math
 
         # Load the full model weights
         model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors') or f.endswith('.bin'))
         weights = torch.load(model_file, map_location='cpu')
 
-        # Calculate
-
+        # Calculate total model size and chunks
+        total_size_bytes = sum(p.nelement() * p.element_size() for p in weights.values())
         num_servers = len(state.tensor_servers) or len(Settings.TENSOR_SERVER_URLS)
-        params_per_chunk = Settings.get_optimal_chunk_size(total_params, num_servers)
 
-
+        # Determine optimal number of chunks based on server count
+        # If 2 servers -> 2 chunks (500MB each for 1GB)
+        # If 3 servers -> 3 chunks (333MB each for 1GB)
+        num_chunks = num_servers
+        bytes_per_chunk = math.ceil(total_size_bytes / num_chunks)
+
+        print(f"[INFO] Total model size: {total_size_bytes / (1024*1024*1024):.2f} GB")
         print(f"[INFO] Available servers: {num_servers}")
-        print(f"[INFO]
+        print(f"[INFO] Creating {num_chunks} chunks")
+        print(f"[INFO] Target chunk size: {bytes_per_chunk / (1024*1024):.2f} MB")
 
         current_chunk = []
-
+        current_chunk_size = 0
         chunk_id = 0
+        chunk_sizes = []  # Track actual chunk sizes for verification
+
+        # Sort weights by size for better distribution
+        sorted_weights = sorted(
+            weights.items(),
+            key=lambda x: x[1].nelement() * x[1].element_size(),
+            reverse=True
+        )
 
         for key, tensor in weights.items():
             tensor_size = tensor.numel()
 
-
+            # Calculate tensor size in bytes
+            tensor_size = tensor.nelement() * tensor.element_size()
+
+            # If adding this tensor would exceed chunk size and we have tensors in current chunk
+            if (current_chunk_size + tensor_size > bytes_per_chunk and current_chunk) or \
+               (chunk_id == num_chunks - 1):  # Last chunk gets remaining tensors
+
                # Save current chunk
                chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors")
-
+                chunk_weights = {k: weights[k] for k in current_chunk}
+                torch.save(chunk_weights, chunk_path)
+
+                # Calculate chunk stats
+                chunk_total_size = sum(weights[k].nelement() * weights[k].element_size()
+                                       for k in current_chunk)
+                chunk_sizes.append(chunk_total_size)
 
                # Create chunk metadata
                state.model_chunks[chunk_id] = ModelChunk(
@@ -174,41 +199,115 @@ async def split_model_weights():
                     files=[f"chunk_{chunk_id}.safetensors"],
                     config={
                         "weight_keys": current_chunk,
-                        "
-                        "
+                        "size_bytes": chunk_total_size,
+                        "num_parameters": sum(weights[k].nelement() for k in current_chunk),
+                        "input_size": weights[current_chunk[0]].size(1) if len(current_chunk) > 0 else 0,
+                        "output_size": weights[current_chunk[-1]].size(0) if len(current_chunk) > 0 else 0
                     }
                 )
 
+                print(f"[INFO] Created chunk {chunk_id}: {chunk_total_size / (1024*1024):.2f} MB, "
+                      f"{len(current_chunk)} tensors")
+
                # Reset for next chunk
                current_chunk = []
-
+                current_chunk_size = 0
                chunk_id += 1
+
+                # If we've created all chunks except last one, put remaining tensors in last chunk
+                if chunk_id == num_chunks - 1:
+                    remaining_tensors = [k for k, _ in sorted_weights if k not in sum([c.config["weight_keys"]
+                                         for c in state.model_chunks.values()], [])]
+                    current_chunk.extend(remaining_tensors)
+                    continue
 
+            # Add tensor to current chunk
            current_chunk.append(key)
-
+            current_chunk_size += tensor_size
 
        # Save last chunk if not empty
        if current_chunk:
            chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors")
-
+            chunk_weights = {k: weights[k] for k in current_chunk}
+            torch.save(chunk_weights, chunk_path)
+
+            # Calculate final chunk stats
+            chunk_total_size = sum(weights[k].nelement() * weights[k].element_size()
+                                   for k in current_chunk)
+            chunk_sizes.append(chunk_total_size)
 
            state.model_chunks[chunk_id] = ModelChunk(
                chunk_id=chunk_id,
                files=[f"chunk_{chunk_id}.safetensors"],
                config={
                    "weight_keys": current_chunk,
+                    "size_bytes": chunk_total_size,
+                    "num_parameters": sum(weights[k].nelement() for k in current_chunk),
                    "input_size": weights[current_chunk[0]].size(1),
                    "output_size": weights[current_chunk[-1]].size(0)
                }
            )
+
+            print(f"[INFO] Created final chunk {chunk_id}: {chunk_total_size / (1024*1024):.2f} MB, "
+                  f"{len(current_chunk)} tensors")
+
+        # Verify distribution
+        total_size_actual = sum(chunk_sizes)
+        size_std_dev = torch.tensor(chunk_sizes).std().item() / (1024*1024)  # MB
+        size_mean = torch.tensor(chunk_sizes).mean().item() / (1024*1024)  # MB
+
+        print(f"\n[INFO] Distribution Summary:")
+        print(f"- Total model size: {total_size_actual / (1024*1024*1024):.2f} GB")
+        print(f"- Number of chunks: {len(state.model_chunks)}")
+        print(f"- Average chunk size: {size_mean:.2f} MB")
+        print(f"- Chunk size std dev: {size_std_dev:.2f} MB")
+        print(f"- Size variation: {(size_std_dev/size_mean*100):.1f}%")
+
+        # Verify all weights were distributed
+        all_distributed = set(sum([c.config["weight_keys"] for c in state.model_chunks.values()], []))
+        if len(all_distributed) != len(weights):
+            missing = set(weights.keys()) - all_distributed
+            print(f"[WARN] Some weights were not distributed: {missing}")
 
-        print(f"[INFO] Split model into {len(state.model_chunks)} chunks")
        return True
 
    except Exception as e:
        print(f"[ERROR] Failed to split model weights: {str(e)}")
        return False
 
+async def send_chunk_to_server(server_url: str, chunk_id: int, chunk_info: Dict):
+    """Send a model chunk to a tensor server"""
+    try:
+        print(f"[INFO] Sending chunk {chunk_id} to server {server_url}")
+        chunk_path = os.path.join(state.model_path, f"chunk_{chunk_id}.safetensors")
+
+        if not os.path.exists(chunk_path):
+            raise Exception(f"Chunk file not found: {chunk_path}")
+
+        chunk_data = {
+            'chunk_id': chunk_id,
+            'files': [f"chunk_{chunk_id}.safetensors"],
+            'config': chunk_info['config']
+        }
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"{server_url}/load_chunk",
+                json=chunk_data,
+                timeout=Settings.TENSOR_SERVER_TIMEOUT
+            ) as response:
+                if response.status != 200:
+                    error_msg = await response.text()
+                    raise Exception(f"Failed to load chunk: {error_msg}")
+
+                result = await response.json()
+                print(f"[INFO] Successfully loaded chunk {chunk_id} to {server_url}")
+                return True
+
+    except Exception as e:
+        print(f"[ERROR] Failed to send chunk {chunk_id} to {server_url}: {str(e)}")
+        return False
+
 async def distribute_model_chunks():
    """Distribute model chunks across available tensor servers"""
    try:
@@ -789,28 +888,58 @@ async def startup_event():
        await initialize_system()
        print("[INFO] Model initialization complete")
 
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Try to connect to pre-configured tensor servers
+        connected_servers = []
+        print(f"[INFO] Attempting to connect to tensor servers...")
+        for url in Settings.TENSOR_SERVER_URLS:
+            try:
+                print(f"[INFO] Testing connection to {url}...")
+                if await check_tensor_server_health(url):
+                    server = TensorServer(url=url)
+                    state.tensor_servers[str(url)] = server
+                    connected_servers.append(server)
+                    print(f"[INFO] Successfully connected to tensor server at {url}")
+            except Exception as e:
+                print(f"[WARN] Failed to connect to tensor server {url}: {str(e)}")
+
+        if connected_servers:
+            print(f"[INFO] Connected to {len(connected_servers)} tensor servers")
+
+            # Split model into chunks
+            print("[INFO] Splitting model into chunks...")
            if await split_model_weights():
-                print(f"[INFO]
-
-
-
-
+                print(f"[INFO] Successfully split model into {len(state.model_chunks)} chunks")
+
+                # Actively distribute chunks to servers
+                print("[INFO] Starting chunk distribution...")
+                distribution_tasks = []
+
+                for chunk_id, chunk in state.model_chunks.items():
+                    # Send each chunk to at least 2 servers if available
+                    target_servers = connected_servers[:2]
+                    for server in target_servers:
+                        print(f"[INFO] Preparing to send chunk {chunk_id} to {server.url}")
+                        task = asyncio.create_task(
+                            send_chunk_to_server(str(server.url), chunk_id, chunk)
+                        )
+                        distribution_tasks.append(task)
+
+                        # Update assignments
+                        if str(server.url) not in chunk.server_assignments:
+                            chunk.server_assignments.append(str(server.url))
+                        if chunk_id not in server.model_chunks:
+                            server.model_chunks.append(chunk_id)
+
+                if distribution_tasks:
+                    print(f"[INFO] Waiting for {len(distribution_tasks)} distribution tasks to complete...")
+                    results = await asyncio.gather(*distribution_tasks, return_exceptions=True)
+                    success_count = sum(1 for r in results if r is True)
+                    print(f"[INFO] Successfully distributed {success_count} chunks out of {len(distribution_tasks)} attempts")
            else:
-                print("[
+                print("[ERROR] Failed to split model weights")
+        else:
+            print("[WARN] No tensor servers available for distribution")
+
    except Exception as e:
        print(f"[ERROR] Startup error: {str(e)}")
 
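Note on the settings hunk above: os.getenv("TENSOR_SERVER_URLS", "").split(",") evaluates to [""] when the variable is unset, which is truthy, so the hard-coded fallback list after "or" is never actually used. A minimal sketch of a parse that only falls back when the variable is empty; this is an illustrative suggestion, not part of the commit:

import os

# Sketch: use the env var only when it actually contains URLs; otherwise fall back.
_raw = os.getenv("TENSOR_SERVER_URLS", "")
TENSOR_SERVER_URLS = [u.strip() for u in _raw.split(",") if u.strip()] or [
    "https://fred808-ilob.hf.space",
    "https://fred808-tserv.hf.space",
    "https://fred808-tserve2.hf.space",
]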
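The startup hunk calls check_tensor_server_health(url) for each configured server, but that helper is not shown in this diff. A minimal sketch of what such a probe could look like, assuming the tensor servers expose a GET /health endpoint and reusing aiohttp as in send_chunk_to_server; the endpoint path and timeout are assumptions, and the real helper in app.py may differ:

import aiohttp

async def check_tensor_server_health(url: str) -> bool:
    # Illustrative sketch only; probes an assumed /health endpoint and treats HTTP 200 as healthy.
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{url}/health", timeout=aiohttp.ClientTimeout(total=10)) as response:
                return response.status == 200
    except Exception as e:
        print(f"[WARN] Health check failed for {url}: {str(e)}")
        return False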