Fred808 commited on
Commit
5b3f4f4
·
verified ·
1 Parent(s): 85a62a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -3
app.py CHANGED
@@ -28,7 +28,7 @@ class Settings:
28
  AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
29
 
30
  # Model settings
31
- MODEL_REPO = "https://huggingface.co/microsoft/florence-2-large"
32
 
33
  # Server settings
34
  TENSOR_SERVER_TIMEOUT = 30 # seconds
@@ -144,15 +144,36 @@ async def split_model_weights():
144
  import torch
145
  import math
146
 
 
 
 
 
 
 
 
 
147
  # Load the full model weights
148
  import torch
149
  from safetensors.torch import load_file as load_safetensors
150
 
151
- # Try safetensors first, then fallback to pytorch
152
  try:
153
  model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
154
  print(f"[INFO] Loading weights from safetensors file: {model_file}")
155
- weights = load_safetensors(model_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  except StopIteration:
157
  # No safetensors file found, try pytorch
158
  model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
 
28
  AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")
29
 
30
  # Model settings
31
+ MODEL_REPO = "https://huggingface.co/facebook/opt-125m"
32
 
33
  # Server settings
34
  TENSOR_SERVER_TIMEOUT = 30 # seconds
 
144
  import torch
145
  import math
146
 
147
+ # Install required packages if not present
148
+ try:
149
+ import safetensors
150
+ except ImportError:
151
+ print("[INFO] Installing required packages...")
152
+ import subprocess
153
+ subprocess.check_call(["pip", "install", "safetensors", "packaging"])
154
+
155
  # Load the full model weights
156
  import torch
157
  from safetensors.torch import load_file as load_safetensors
158
 
159
+ # Try safetensors first with chunked loading, then fallback to pytorch
160
  try:
161
  model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
162
  print(f"[INFO] Loading weights from safetensors file: {model_file}")
163
+ try:
164
+ # Try direct loading first
165
+ weights = load_safetensors(model_file)
166
+ except Exception as e:
167
+ if "header too large" in str(e):
168
+ print("[INFO] Large header detected, attempting chunked loading...")
169
+ from safetensors import safe_open
170
+ weights = {}
171
+ with safe_open(model_file, framework="pt") as f:
172
+ for key in f.keys():
173
+ weights[key] = f.get_tensor(key)
174
+ print("[INFO] Successfully loaded weights using chunked loading")
175
+ else:
176
+ raise e
177
  except StopIteration:
178
  # No safetensors file found, try pytorch
179
  model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))