refactor inference, warn if model is frozen
Files changed:
- scripts/finetune.py +13 -3
- src/axolotl/utils/models.py +6 -0
- src/axolotl/utils/trainer.py +1 -1
scripts/finetune.py
@@ -6,9 +6,11 @@ import random
 import signal
 import sys
 from pathlib import Path
+from typing import Optional

 import fire
 import torch
+import transformers
 import yaml
 from attrdict import AttrDefault

@@ -46,6 +48,15 @@ def choose_device(cfg):
         cfg.device_map = {"": cfg.device}


+def get_multi_line_input() -> Optional[str]:
+    print("Give me an instruction (Ctrl + Z to finish): ")
+    instruction = ""
+    for line in sys.stdin:
+        instruction += line
+    # instruction = pathlib.Path("/proc/self/fd/0").read_text()
+    return instruction
+
+
 def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
     tokenizer.add_special_tokens({"unk_token": "<unk>"})
     tokenizer.add_special_tokens({"bos_token": "<s>"})
@@ -55,8 +66,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):

     while True:
         # support for multiline inputs
-        instruction = pathlib.Path("/proc/self/fd/0").read_text()
+        instruction = get_multi_line_input()
         if not instruction:
             return
         prompt = prompter_module().build_prompt(instruction=instruction)
@@ -66,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         with torch.no_grad():
             # gc = GenerationConfig()  # TODO swap out and use this
             generated = model.generate(
-                inputs=batch["input_ids"].to(
+                inputs=batch["input_ids"].to(cfg.device),
                 do_sample=True,
                 use_cache=True,
                 repetition_penalty=1.1,
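For reference, a minimal standalone sketch of the new input flow: get_multi_line_input reads stdin line by line until EOF (Ctrl + Z on Windows consoles, Ctrl + D on Unix shells), replacing the Linux-only read of /proc/self/fd/0. The __main__ harness below is illustrative only and is not part of the script.

    import sys
    from typing import Optional


    def get_multi_line_input() -> Optional[str]:
        # Read every line from stdin until EOF; a portable replacement for
        # reading /proc/self/fd/0, which only exists on Linux.
        print("Give me an instruction (Ctrl + Z to finish): ")
        instruction = ""
        for line in sys.stdin:
            instruction += line
        return instruction


    if __name__ == "__main__":
        text = get_multi_line_input()
        if not text:
            # An empty read (immediate EOF) ends the session, matching do_inference.
            print("no instruction given, exiting")
        else:
            print(f"received {len(text.splitlines())} line(s)")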
src/axolotl/utils/models.py
@@ -183,6 +183,12 @@ def load_model(
         model.is_parallelizable = True
         model.model_parallel = True

+    requires_grad = []
+    for name, param in model.named_parameters(recurse=True):
+        if param.requires_grad:
+            requires_grad.append(f"{name}: {param.requires_grad}")
+    if len(requires_grad) == 0:
+        logging.warning("there are no parameters that require gradient updates")

     # TODO resume_from_checkpoint handling
     return model, tokenizer, lora_config
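The added block walks model.named_parameters() and warns when nothing is trainable, which typically means the model was loaded fully frozen (for example, a base model with no trainable adapter attached). A self-contained sketch of the same check follows; the helper name warn_if_frozen is hypothetical and not part of the codebase.

    import logging

    import torch.nn as nn


    def warn_if_frozen(model: nn.Module) -> None:
        # Collect the names of parameters that would receive gradient updates.
        trainable = [
            name
            for name, param in model.named_parameters(recurse=True)
            if param.requires_grad
        ]
        if len(trainable) == 0:
            logging.warning("there are no parameters that require gradient updates")


    model = nn.Linear(4, 2)
    for param in model.parameters():
        param.requires_grad = False  # freeze everything
    warn_if_frozen(model)  # emits the warning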
src/axolotl/utils/trainer.py
@@ -105,7 +105,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
         optim=cfg.optimizer if cfg.optimizer else None,
         lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
-        weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
+        weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         **training_arguments_kwargs,
     )

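The trainer.py change replaces a truthiness test with an explicit is not None check. With the 0.0 fallback used here the two expressions happen to produce the same number, but the explicit form is the one that preserves a user-supplied falsy value. A small illustrative sketch; the helper names and the non-zero default are hypothetical, chosen only to make the difference visible.

    # Why `is not None` beats a truthiness check: a falsy-but-valid value such as
    # 0.0 would otherwise be silently swapped for the fallback.
    def pick_weight_decay_truthy(value, default=0.01):
        return value if value else default


    def pick_weight_decay_explicit(value, default=0.01):
        return value if value is not None else default


    assert pick_weight_decay_truthy(0.0) == 0.01    # user's 0.0 is lost
    assert pick_weight_decay_explicit(0.0) == 0.0   # user's 0.0 is respected
    assert pick_weight_decay_explicit(None) == 0.01  # unset falls back to default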