fix torch_dtype for model load
src/axolotl/utils/models.py
CHANGED
@@ -62,9 +62,12 @@ def load_model(
         logging.info("patching with xformers attention")
         hijack_llama_attention()

-
-
-
+    if cfg.bf16:
+        torch_dtype = torch.bfloat16
+    elif cfg.load_in_8bit or cfg.fp16:
+        torch_dtype = torch.float16
+    else:
+        torch_dtype = torch.float32
     try:
         if cfg.load_4bit:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
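For context, the block added in this hunk picks the torch dtype from the training config before the model is loaded: bfloat16 when cfg.bf16 is set, float16 when the model is loaded in 8-bit or cfg.fp16 is set, and float32 otherwise. The sketch below restates that selection as a standalone function; the helper name, the SimpleNamespace config, and the from_pretrained() comment at the end are illustrative assumptions about how torch_dtype is consumed downstream, not part of this diff.

import torch
from types import SimpleNamespace


def pick_torch_dtype(cfg):
    """Mirror of the dtype selection added in this hunk (hypothetical helper name)."""
    if cfg.bf16:
        # bf16 requested explicitly in the config
        return torch.bfloat16
    if cfg.load_in_8bit or cfg.fp16:
        # 8-bit loading and fp16 training both keep non-quantized weights in half precision
        return torch.float16
    # default: full precision
    return torch.float32


# Example config object; the field names mirror those used in the diff.
cfg = SimpleNamespace(bf16=False, load_in_8bit=True, fp16=False)
torch_dtype = pick_torch_dtype(cfg)
print(torch_dtype)  # torch.float16

# Assumed downstream use (not shown in this hunk): the value would typically be
# passed to transformers when instantiating the model, e.g.
#   AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch_dtype, ...)

Note the ordering: cfg.bf16 is checked first, so a config that enables both bf16 and 8-bit loading resolves to bfloat16.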