attempt xformers hijack attention
src/axolotl/utils/models.py CHANGED

@@ -43,6 +43,10 @@ def load_model(
 
         logging.info("patching with flash attention")
         replace_llama_attn_with_flash_attn()
+    elif is_llama_derived_model and cfg.xformers_attention:
+        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
+        logging.info("patching with xformers attention")
+        hijack_llama_attention()
 
     torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
     try: