add gptneox embeddings, fix phi2 inputs, also fix the casting (#1083)
src/axolotl/utils/lora_embeddings.py
CHANGED
@@ -8,5 +8,7 @@ def get_linear_embedding_layers(model_type):
     returns the linear embedding layers needed for loras, dependent on the model arch
     """
     if model_type == "phi-msft":
-        return ["embd", "lm_head.linear"]
-    return ["embed_tokens", "lm_head"]
+        return ["embd.wte", "lm_head.linear"]
+    if model_type == "gpt_neox":
+        return ["embed_in", "embed_out"]
+    return ["embed_tokens", "lm_head"]
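For reference, a quick sanity check of the updated mapping (a sketch, assuming axolotl is importable; the "llama" case simply falls through to the default):

    from axolotl.utils.lora_embeddings import get_linear_embedding_layers

    # phi-msft now points at the wte submodule of the embedding block
    assert get_linear_embedding_layers("phi-msft") == ["embd.wte", "lm_head.linear"]
    # new gpt_neox mapping added by this change
    assert get_linear_embedding_layers("gpt_neox") == ["embed_in", "embed_out"]
    # any other architecture falls through to the generic names
    assert get_linear_embedding_layers("llama") == ["embed_tokens", "lm_head"]
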
src/axolotl/utils/models.py
CHANGED
@@ -588,13 +588,14 @@ def load_model(
     log_gpu_memory_usage(LOG, "after model load", model.device)

     # make sure these are fp32 per Ramesh et al. (2021)
+    embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
     for name, module in model.named_modules():
         if "norm" in name:
             module.to(torch.float32)
         if model_config.model_type == "btlm":
             # don't upcast lm_head for btlm
             continue
-        if "lm_head" in name or "embed_tokens" in name:
+        if any(m in name for m in embedding_modules):
             if hasattr(module, "weight"):
                 module.to(torch.float32)

@@ -619,15 +620,12 @@ def load_model(

     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or (
-        cfg.flash_attention
-        and (cfg.is_llama_derived_model or cfg.is_mistral_derived_model)
-    ):
+    if needs_fa2_dtype or cfg.flash_attention:
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
                 module.to(cfg.torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
+            if any(m in name for m in embedding_modules):
                 if hasattr(module, "weight"):
                     module.to(cfg.torch_dtype)
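The casting loops now match module names by substring against whatever get_linear_embedding_layers returns for the configured architecture, rather than hard-coded names. A minimal sketch of that matching, using module names that are illustrative of a HF gpt_neox checkpoint (not taken from this diff):

    from axolotl.utils.lora_embeddings import get_linear_embedding_layers

    embedding_modules = get_linear_embedding_layers("gpt_neox")  # ["embed_in", "embed_out"]
    for name in ["gpt_neox.embed_in", "gpt_neox.layers.0.attention.dense", "embed_out"]:
        # substring match: "gpt_neox.embed_in" contains "embed_in"; the dense layer matches nothing
        print(name, any(m in name for m in embedding_modules))
    # gpt_neox.embed_in True
    # gpt_neox.layers.0.attention.dense False
    # embed_out True
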
tests/core/test_trainer_builder.py
CHANGED
@@ -30,6 +30,7 @@ def fixture_cfg():
             "adam_epsilon": 0.00001,
             "dataloader_num_workers": 1,
             "dataloader_pin_memory": True,
+            "model_config_type": "llama",
         }
     )

tests/test_validation.py
CHANGED
@@ -770,7 +770,7 @@ class ValidationCheckModelConfig(BaseValidation):
                 "adapter": "qlora",
                 "load_in_4bit": True,
                 "tokens": ["<|imstart|>"],
-                "lora_modules_to_save": ["embd", "lm_head.linear"],
+                "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
             }
         )

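The test config pairs added tokens with a QLoRA adapter, so lora_modules_to_save has to name the phi-msft embedding layers using the new paths. An illustrative check of that relationship (a sketch only, not axolotl's actual validation code):

    from axolotl.utils.lora_embeddings import get_linear_embedding_layers

    lora_modules_to_save = ["embd.wte", "lm_head.linear"]
    required = get_linear_embedding_layers("phi-msft")  # ["embd.wte", "lm_head.linear"]
    assert all(layer in lora_modules_to_save for layer in required)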