Fred808 committed on
Commit
dd87592
·
verified ·
1 Parent(s): 6b7d2ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -38
app.py CHANGED
@@ -148,16 +148,21 @@ async def split_model_weights():
148
  import shutil
149
  from pathlib import Path
150
 
151
- # Find model file (safetensors or pytorch)
152
- try:
153
- model_file = next(f for f in state.model_files.values() if f.endswith('.safetensors'))
154
- print(f"[INFO] Found safetensors file: {model_file}")
155
- except StopIteration:
156
- try:
157
- model_file = next(f for f in state.model_files.values() if f.endswith('.bin'))
158
- print(f"[INFO] Found PyTorch file: {model_file}")
159
- except StopIteration:
160
- raise Exception("No model weight files found")
 
 
 
 
 
161
 
162
  # Get file size and calculate chunks
163
  try:
@@ -690,35 +695,18 @@ async def download_model_files():
690
  try:
691
  print(f"[INFO] Downloading model files from {repo_id}...")
692
 
693
- # First download config and other small files
694
- config_files = ["config.json", "tokenizer_config.json", "vocab.json", "generation_config.json"]
695
- for filename in config_files:
696
- try:
697
- file_path = hf_hub_download(
698
- repo_id=repo_id,
699
- filename=filename,
700
- local_dir=model_path,
701
- force_download=True
702
- )
703
- print(f"[INFO] Downloaded {filename}")
704
- except Exception as e:
705
- print(f"[WARN] Could not download {filename}: {str(e)}")
706
 
707
- # Then download the model weights
708
- print("[INFO] Downloading model weights (this may take a while)...")
709
- for weight_file in ["pytorch_model.bin", "model.safetensors"]:
710
- try:
711
- file_path = hf_hub_download(
712
- repo_id=repo_id,
713
- filename=weight_file,
714
- local_dir=model_path,
715
- force_download=True
716
- )
717
- print(f"[INFO] Successfully downloaded {weight_file}")
718
- break # Stop after first successful weight file download
719
- except Exception as e:
720
- print(f"[WARN] Could not download {weight_file}: {str(e)}")
721
- continue
722
 
723
  print(f"[INFO] All files downloaded to {model_path}")
724
  state.is_model_loaded = True
 
148
  import shutil
149
  from pathlib import Path
150
 
151
+ # Find model weight files (safetensors or pytorch)
152
+ weight_files = [f for f in state.model_files.values() if f.endswith(('.safetensors', '.bin'))]
153
+
154
+ if not weight_files:
155
+ raise Exception("No model weight files found")
156
+
157
+ # The current splitting logic only supports splitting a single file.
158
+ # If there are multiple files, we assume they are sharded and need a different approach.
159
+ # For now, we will select the largest file to split, or the first one if all are small.
160
+ model_file = max(weight_files, key=os.path.getsize) if len(weight_files) > 1 else weight_files[0]
161
+
162
+ if len(weight_files) > 1:
163
+ print(f"[WARN] Found multiple weight files. Selecting the largest one for splitting: {model_file}")
164
+ else:
165
+ print(f"[INFO] Found model weight file: {model_file}")
166
 
167
  # Get file size and calculate chunks
168
  try:
 
695
  try:
696
  print(f"[INFO] Downloading model files from {repo_id}...")
697
 
698
+ # Use snapshot_download to get all necessary files at once, which supports all weight file names
699
+ print("[INFO] Downloading all model files (this may take a while)...")
 
 
 
 
 
 
 
 
 
 
 
700
 
701
+ # snapshot_download is the most robust way to get all files matching patterns
702
+ # This addresses the user's request to download model files that are not just "pytorch.bin"
703
+ model_path = snapshot_download(
704
+ repo_id=repo_id,
705
+ local_dir=model_path,
706
+ allow_patterns=["*.bin", "*.safetensors", "*.json", "*.txt", "tokenizer.model"],
707
+ ignore_patterns=["*.msgpack", "*.onnx"], # Ignore non-PyTorch/safetensors formats
708
+ force_download=True
709
+ )
 
 
 
 
 
 
710
 
711
  print(f"[INFO] All files downloaded to {model_path}")
712
  state.is_model_loaded = True