Merge pull request #336 from tmm1/flash-attn
Fix flash-attn + qlora not working with llama models

src/axolotl/{flash_attn.py → monkeypatch/llama_attn_hijack_flash.py}
RENAMED (file without changes)

src/axolotl/utils/models.py
CHANGED

@@ -92,7 +92,9 @@ def load_model(
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and not cfg.inference:
-            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_attn_with_flash_attn,
+            )
 
             LOG.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
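
For context, a minimal sketch (not the axolotl implementation) of what a "hijack"-style monkeypatch such as replace_llama_attn_with_flash_attn does: rebind LlamaAttention.forward at runtime. The real patch swaps in a flash-attn-backed forward; the stand-in below only wraps the stock implementation so the example stays self-contained and runnable.

import transformers.models.llama.modeling_llama as llama_modeling


def replace_llama_attn_with_wrapped_attn():
    """Sketch of the monkeypatch pattern: rebind LlamaAttention.forward at runtime."""
    original_forward = llama_modeling.LlamaAttention.forward

    def wrapped_forward(self, *args, **kwargs):
        # The actual axolotl patch computes attention with flash-attn kernels
        # here; this stand-in simply delegates to the stock implementation.
        return original_forward(self, *args, **kwargs)

    llama_modeling.LlamaAttention.forward = wrapped_forward

Deferring the real import into the cfg.flash_attention branch keeps flash-attn, a CUDA-only optional dependency, off the CPU/MPS and inference code paths entirely.
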
@@ -331,6 +333,16 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch_dtype)
+
     model, lora_config = load_adapter(model, cfg, adapter)
 
     if cfg.ddp and not load_in_8bit:
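
The block added above is needed because k-bit preparation (peft's prepare_model_for_kbit_training) upcasts the norm layers to fp32 for numerical stability, while flash-attn kernels require fp16/bf16 activations. Below is a hedged, self-contained sketch of the same fix-up outside axolotl; torch_dtype mirrors the variable used in the diff, and the toy model is purely illustrative.

import torch
import torch.nn as nn

torch_dtype = torch.bfloat16  # or torch.float16, whichever dtype training uses

# Stand-in for a k-bit-prepared model whose norm and head layers sit in fp32.
model = nn.ModuleDict(
    {
        "input_layernorm": nn.LayerNorm(16, dtype=torch.float32),
        "lm_head": nn.Linear(16, 32, dtype=torch.float32),
    }
)

# Same walk as the diff: cast norms back down, plus lm_head / embed_tokens
# weights, so every tensor reaching flash-attn shares the training dtype.
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch_dtype)
    if "lm_head" in name or "embed_tokens" in name:
        if hasattr(module, "weight"):
            module.to(torch_dtype)
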
@@ -407,14 +419,6 @@ def load_llama_adapter(model, cfg):
     else:
         model = get_peft_model(model, peft_config)
 
-    if cfg.flash_attention:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch.float16)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
-                    module.to(torch.float16)
-
     model.print_trainable_parameters()
 
     return model, peft_config
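
The cast deleted here lived in load_llama_adapter, so it only ran for that adapter path and was hard-coded to torch.float16; the block moved into load_model (previous hunk) runs for any adapter, qlora included, and uses the configured torch_dtype instead. A small, hypothetical sanity check one could run after load_model to confirm the fix took effect (assert_norms_match_dtype is not part of axolotl):

import torch


def assert_norms_match_dtype(model, expected_dtype=torch.bfloat16):
    # Every *norm* parameter should be in the training dtype after loading;
    # a leftover fp32 norm is what makes flash-attn raise a dtype error.
    for name, param in model.named_parameters():
        if "norm" in name:
            assert param.dtype == expected_dtype, (name, param.dtype)
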