Update src/axolotl/utils/models.py
Co-authored-by: Aman Gupta Karmani <[email protected]>
src/axolotl/utils/models.py
CHANGED
@@ -368,7 +368,7 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if (
+    if fix_dtype and (
         cfg.flash_attention and cfg.is_llama_derived_model
     ):
         for name, module in model.named_modules():
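For context, a minimal sketch of the dtype-conversion loop that this `fix_dtype` guard wraps, assuming the norm (and embedding/head) modules are cast back to the configured compute dtype. The helper name `fix_norm_dtype`, the `cfg.bf16` flag, and the exact module-name matching are assumptions for illustration, not the file's verbatim code.

```python
import torch


def fix_norm_dtype(model, cfg):
    """Sketch: cast LlamaRMSNorm (and embed/head) weights back to fp16/bf16.

    After k-bit training prep or full finetuning these layers end up in fp32,
    which flash-attn kernels reject; casting restores a consistent dtype.
    """
    # Assumed cfg flags; axolotl's config exposes bf16/fp16 toggles.
    target_dtype = torch.bfloat16 if cfg.bf16 else torch.float16
    for name, module in model.named_modules():
        # Norm layers are matched by name; cast their weights down.
        if "norm" in name:
            module.to(target_dtype)
        # Embedding and output head also need matching dtypes.
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module.to(target_dtype)
    return model
```

With the change above, this conversion would only run when the caller passes `fix_dtype=True` and the model is Llama-derived with flash attention enabled, rather than unconditionally.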