attempt to find linear modules for qlora

src/axolotl/utils/models.py (+24 -2)
@@ -4,6 +4,7 @@ import os
 from pathlib import Path
 from typing import Optional, Tuple, TYPE_CHECKING
 
+import bitsandbytes as bnb
 import torch
 import transformers
 from torch import nn
@@ -334,6 +335,24 @@ def load_llama_adapter(model, cfg):
     return model, peft_config
 
 
+def find_all_linear_names(bits, model):
+    cls = (
+        bnb.nn.Linear4bit
+        if bits == 4
+        else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
+    )
+    lora_module_names = set()
+    for name, module in model.named_modules():
+        if isinstance(module, cls):
+            names = name.split(".")
+            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+    if "lm_head" in lora_module_names:  # needed for 16-bit
+        lora_module_names.remove("lm_head")
+
+    return list(lora_module_names)
+
+
 def load_lora(model, cfg):
     # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
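The new helper walks model.named_modules(), keeps the leaf attribute name of every module whose class matches the quantized linear type for the requested bit width (bnb.nn.Linear4bit for 4-bit, bnb.nn.Linear8bitLt for 8-bit, plain torch.nn.Linear otherwise), and drops lm_head. Below is a minimal sketch of the same traversal, not part of the commit: it takes the bits=None path (plain torch.nn.Linear) and uses a made-up ToyModel so it runs on CPU without bitsandbytes.

# Sketch only: mirrors the traversal in the new find_all_linear_names.
# With bits=4/8 the real helper matches bnb.nn.Linear4bit / bnb.nn.Linear8bitLt
# instead of torch.nn.Linear; ToyModel is illustrative, not from the codebase.
import torch
from torch import nn


def find_linear_names_sketch(model, cls=torch.nn.Linear):
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split(".")
            # keep only the leaf attribute name, e.g. "layers.0.q_proj" -> "q_proj"
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    lora_module_names.discard("lm_head")  # the commit drops lm_head explicitly
    return list(lora_module_names)


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(16, 16)
        self.v_proj = nn.Linear(16, 16)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock() for _ in range(2)])
        self.lm_head = nn.Linear(16, 32)


print(sorted(find_linear_names_sketch(ToyModel())))
# ['q_proj', 'v_proj'] -- duplicate names across layers collapse, lm_head is excluded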
@@ -343,12 +362,15 @@ def load_lora(model, cfg):
         PeftModel,
     )
 
-
+    bits = 4 if cfg.load_in_4bits else 8 if cfg.load_in_8bits else None
+    linear_names = find_all_linear_names(bits, model)
+    logging.info(f"found linear modules: {repr(linear_names)}")
+    lora_target_modules = cfg.lora_target_modules + linear_names
 
     lora_config = LoraConfig(
         r=cfg.lora_r,
         lora_alpha=cfg.lora_alpha,
-        target_modules=cfg.lora_target_modules,
+        target_modules=lora_target_modules,
         lora_dropout=cfg.lora_dropout,
         fan_in_fan_out=cfg.lora_fan_in_fan_out,
         modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
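In load_lora the discovered names are appended to the configured cfg.lora_target_modules before LoraConfig is built. A small sketch of how that list comes out, with illustrative values only (module names typical of a LLaMA-style model, not taken from the commit):

# Illustrative values only; shows the same `+` concatenation used in the diff.
configured = ["q_proj", "v_proj"]                     # e.g. cfg.lora_target_modules from the YAML
discovered = ["q_proj", "k_proj", "v_proj", "o_proj",
              "gate_proj", "up_proj", "down_proj"]    # what find_all_linear_names might return
lora_target_modules = configured + discovered
print(lora_target_modules)
# ['q_proj', 'v_proj', 'q_proj', 'k_proj', 'v_proj', 'o_proj',
#  'gate_proj', 'up_proj', 'down_proj']
# q_proj / v_proj appear twice: nothing de-duplicates the configured and discovered lists.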