improve llama check and fix safetensors file check
Browse files- scripts/finetune.py +2 -4
scripts/finetune.py
CHANGED
|
@@ -85,14 +85,12 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
|
|
| 85 |
raise e
|
| 86 |
|
| 87 |
try:
|
| 88 |
-
if cfg.load_4bit and "llama" in base_model:
|
| 89 |
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
| 90 |
from huggingface_hub import snapshot_download
|
| 91 |
|
| 92 |
cache_model_path = Path(snapshot_download(base_model))
|
| 93 |
-
|
| 94 |
-
cache_model_path.glob("*.pt")
|
| 95 |
-
files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensor')) + list(cache_model_path.glob('*.bin'))
|
| 96 |
if len(files) > 0:
|
| 97 |
model_path = str(files[0])
|
| 98 |
else:
|
|
|
|
| 85 |
raise e
|
| 86 |
|
| 87 |
try:
|
| 88 |
+
if cfg.load_4bit and "llama" in base_model or "llama" in cfg.model_type.lower():
|
| 89 |
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
| 90 |
from huggingface_hub import snapshot_download
|
| 91 |
|
| 92 |
cache_model_path = Path(snapshot_download(base_model))
|
| 93 |
+
files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensors')) + list(cache_model_path.glob('*.bin'))
|
|
|
|
|
|
|
| 94 |
if len(files) > 0:
|
| 95 |
model_path = str(files[0])
|
| 96 |
else:
|