Update app.py
app.py
CHANGED
@@ -28,7 +28,7 @@ class Settings:
     AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
 
     # Model settings
-    MODEL_REPO = "https://huggingface.co/
+    MODEL_REPO = "https://huggingface.co/facebook/opt-125m"
 
     # Server settings
     TENSOR_SERVER_TIMEOUT = 30  # seconds
@@ -144,15 +144,36 @@ async def split_model_weights():
     import torch
     import math
 
+    # Install required packages if not present
+    try:
+        import safetensors
+    except ImportError:
+        print("[INFO] Installing required packages...")
+        import subprocess
+        subprocess.check_call(["pip", "install", "safetensors", "packaging"])
+
     # Load the full model weights
     import torch
     from safetensors.torch import load_file as load_safetensors
 
-    # Try safetensors first, then fallback to pytorch
+    # Try safetensors first with chunked loading, then fallback to pytorch
     try:
         model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
         print(f"[INFO] Loading weights from safetensors file: {model_file}")
-        weights = load_safetensors(model_file)
+        try:
+            # Try direct loading first
+            weights = load_safetensors(model_file)
+        except Exception as e:
+            if "header too large" in str(e):
+                print("[INFO] Large header detected, attempting chunked loading...")
+                from safetensors import safe_open
+                weights = {}
+                with safe_open(model_file, framework="pt") as f:
+                    for key in f.keys():
+                        weights[key] = f.get_tensor(key)
+                print("[INFO] Successfully loaded weights using chunked loading")
+            else:
+                raise e
     except StopIteration:
         # No safetensors file found, try pytorch
         model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
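For reference, a minimal standalone sketch of the fallback path this commit introduces, using the public safetensors API (safe_open, keys, get_tensor). The MODEL_FILE path and the "header too large" string check are illustrative assumptions mirroring the diff above, not part of the repository.

from safetensors import safe_open
from safetensors.torch import load_file as load_safetensors

MODEL_FILE = "model.safetensors"  # hypothetical local checkpoint path

try:
    # Fast path: deserialize the whole checkpoint into a dict of tensors at once
    weights = load_safetensors(MODEL_FILE)
except Exception as exc:
    if "header too large" not in str(exc):
        raise
    # Fallback: open the file lazily and copy tensors out one key at a time
    weights = {}
    with safe_open(MODEL_FILE, framework="pt") as f:
        for key in f.keys():
            weights[key] = f.get_tensor(key)

print(f"Loaded {len(weights)} tensors")

The fallback reads tensors individually instead of deserializing the whole file in one call, which appears to be what the "chunked loading" log message in the diff refers to.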