Update app.py
app.py CHANGED
@@ -148,16 +148,21 @@ async def split_model_weights():
     import shutil
     from pathlib import Path
 
-    # Find model
+    # Find model weight files (safetensors or pytorch)
+    weight_files = [f for f in state.model_files.values() if f.endswith(('.safetensors', '.bin'))]
+
+    if not weight_files:
+        raise Exception("No model weight files found")
+
+    # The current splitting logic only supports splitting a single file.
+    # If there are multiple files, we assume they are sharded and need a different approach.
+    # For now, select the largest file when there are several, or the single file otherwise.
+    model_file = max(weight_files, key=os.path.getsize) if len(weight_files) > 1 else weight_files[0]
+
+    if len(weight_files) > 1:
+        print(f"[WARN] Found multiple weight files. Selecting the largest one for splitting: {model_file}")
+    else:
+        print(f"[INFO] Found model weight file: {model_file}")
 
     # Get file size and calculate chunks
     try:
@@ -690,35 +695,18 @@ async def download_model_files():
     try:
         print(f"[INFO] Downloading model files from {repo_id}...")
 
-        # Download config files
-        for filename in config_files:
-            try:
-                file_path = hf_hub_download(
-                    repo_id=repo_id,
-                    filename=filename,
-                    local_dir=model_path,
-                    force_download=True
-                )
-                print(f"[INFO] Downloaded {filename}")
-            except Exception as e:
-                print(f"[WARN] Could not download {filename}: {str(e)}")
-
-        # Download weight files, trying candidate filenames until one succeeds
-        for weight_file in weight_files:
-            try:
-                file_path = hf_hub_download(
-                    repo_id=repo_id,
-                    filename=weight_file,
-                    local_dir=model_path,
-                    force_download=True
-                )
-                print(f"[INFO] Successfully downloaded {weight_file}")
-                break  # Stop after first successful weight file download
-            except Exception as e:
-                print(f"[WARN] Could not download {weight_file}: {str(e)}")
-                continue
+        # Use snapshot_download to get all necessary files at once, which supports all weight file names
+        print("[INFO] Downloading all model files (this may take a while)...")
+
+        # snapshot_download is the most robust way to get all files matching patterns
+        # This addresses the request to download weight files that are not just "pytorch_model.bin"
+        model_path = snapshot_download(
+            repo_id=repo_id,
+            local_dir=model_path,
+            allow_patterns=["*.bin", "*.safetensors", "*.json", "*.txt", "tokenizer.model"],
+            ignore_patterns=["*.msgpack", "*.onnx"],  # Ignore non-PyTorch/safetensors formats
+            force_download=True
+        )
 
         print(f"[INFO] All files downloaded to {model_path}")
         state.is_model_loaded = True
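A note on the first hunk: `max(weight_files, key=os.path.getsize)` assumes the values in `state.model_files` are usable filesystem paths and that `os` is already imported (the hunk itself only imports `shutil` and `Path`). Below is a minimal, self-contained sketch of the same selection rule; the temporary directory and dummy shard files stand in for the app's state and are not part of the actual code.

import os
import tempfile

# Dummy shards standing in for state.model_files (filename -> path); sizes are arbitrary
tmp = tempfile.mkdtemp()
model_files = {}
for name, size in [("model-00001-of-00002.safetensors", 300),
                   ("model-00002-of-00002.safetensors", 100),
                   ("config.json", 10)]:
    path = os.path.join(tmp, name)
    with open(path, "wb") as f:
        f.write(b"\0" * size)
    model_files[name] = path

# Same rule as the diff: weight files only, largest one when there are several
weight_files = [p for p in model_files.values() if p.endswith((".safetensors", ".bin"))]
if not weight_files:
    raise Exception("No model weight files found")
model_file = max(weight_files, key=os.path.getsize) if len(weight_files) > 1 else weight_files[0]
print(f"[INFO] Selected for splitting: {model_file}")  # picks the 300-byte shard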
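The first hunk hands `model_file` to the "# Get file size and calculate chunks" step, whose body the diff does not show. For orientation, here is a sketch of what single-file chunking typically looks like; the helper name, default chunk size, and part-naming scheme are illustrative assumptions, not the app's actual implementation.

from pathlib import Path

def split_file(src: str, out_dir: str, chunk_size: int = 1 << 30) -> list[Path]:
    """Split src into numbered chunks of at most chunk_size bytes (default 1 GiB)."""
    src_path, out = Path(src), Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    total = src_path.stat().st_size
    n_chunks = (total + chunk_size - 1) // chunk_size  # ceiling division
    chunks = []
    with src_path.open("rb") as f:
        for i in range(n_chunks):
            part = out / f"{src_path.name}.part{i:04d}"
            remaining = min(chunk_size, total - i * chunk_size)
            with part.open("wb") as dst:
                # Copy in 8 MiB buffers to keep memory flat for multi-GB weights
                while remaining > 0:
                    buf = f.read(min(remaining, 8 * 1024 * 1024))
                    if not buf:
                        break
                    dst.write(buf)
                    remaining -= len(buf)
            chunks.append(part)
    return chunks

This single-file limitation is also why the hunk warns about multiple weight files: splitting only the largest shard of a model-0000X-of-0000N checkpoint would leave the remaining shards unhandled.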
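For the second hunk: `snapshot_download` is the standard `huggingface_hub` API for fetching a repo in one call, and unlike the removed per-filename loop it also picks up sharded weights such as model-00001-of-00002.safetensors. A standalone sketch, with the repo id and target directory as placeholder values:

from huggingface_hub import snapshot_download

# Placeholder values; the app derives repo_id and model_path from its own state
repo_id = "gpt2"
model_path = "./models/gpt2"

local_path = snapshot_download(
    repo_id=repo_id,
    local_dir=model_path,
    # Configs, tokenizers, and PyTorch/safetensors weights only
    allow_patterns=["*.bin", "*.safetensors", "*.json", "*.txt", "tokenizer.model"],
    # Skip Flax/ONNX artifacts that some repos also ship
    ignore_patterns=["*.msgpack", "*.onnx"],
)
print(f"Files materialized under {local_path}")

Note that `force_download=True`, as used in the diff, bypasses the local cache and re-fetches every file on each call; omitting it lets repeated loads reuse already-downloaded files.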