"""This script trains a model on a small text dataset and measures the memory consumption, as well as a few other |
|
|
useful metrics. |
|
|
|
|
|
Example: |
|
|
|
|
|
Get help: |
|
|
|
|
|
```bash |
|
|
python train_memory.py --help |
|
|
``` |
|
|
|
|
|
Train the google/gemma-2-2b model with a LoRA config json at the indicated location. |
|
|
|
|
|
```bash |
|
|
python train_memory.py "google/gemma-2-2b" --max_seq_length 256 --batch_size 1 --rank 32 --dtype bfloat16 --path_config <path-to-adapter-config.json> |
|
|
``` |
|
|
|
|
|
Fully fine-tune the model (i.e. without LoRA) by setting the rank to 0: |
|
|
|
|
|
```bash |
|
|
python train_memory.py "google/gemma-2-2b" --rank 0 |
|
|
``` |
|
|
|
|
|
Get an estimate of the size of the hidden states by passing `--monitor_tensors`. This trains just for a single epoch. For realistic estimates, the batch size for this: |
|
|
|
|
|
```bash |
|
|
python train_memory.py "google/gemma-2-2b" --max_seq_length 256 --batch_size 32 --rank 32 --dtype bfloat16 --path_config configs/lora_rank-32_embedding-lora/ --monitor_tensors |
|
|
``` |
|
|
|
|
|
""" |
|
|
|
|
|
import argparse
import gc
import os
import sys
import tempfile
import time
import warnings
from collections import Counter
from contextlib import nullcontext
from functools import partial

import torch
from datasets import load_dataset
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME


warnings.filterwarnings("ignore")

device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
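# Approximate bytes per element for each supported dtype; used below to convert tensor shapes into byte sizes.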
dtype_to_bytes_linear = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1, "int4": 0.5}


def init_accelerator():
    torch.manual_seed(0)
    if device == "cpu":
        return

    device_module = getattr(torch, device, torch.cuda)
    device_module.reset_peak_memory_stats()
    device_module.manual_seed_all(0)
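
    # Instantiating a tiny layer on the device up front (presumably) absorbs the one-time cost of initializing the
    # accelerator context, so that this fixed overhead is not attributed to the model in the measurements below.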
    nn.Linear(1, 1).to(device)


def get_data(tokenizer):
    def tokenize(samples):
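        # Tokenize the quotes, then manually truncate input_ids and attention_mask to the configured
        # model_max_length.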
        tokenized = tokenizer(samples["quote"])
        tokenized["input_ids"] = [input_ids[: tokenizer.model_max_length] for input_ids in tokenized["input_ids"]]
        tokenized["attention_mask"] = [
            input_ids[: tokenizer.model_max_length] for input_ids in tokenized["attention_mask"]
        ]
        return tokenized

    data = load_dataset("ybelkada/english_quotes_copy")
    data = data.map(tokenize, batched=True)
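
    # Drop the original text columns so that only input_ids and attention_mask remain; tokenizer.pad can then
    # batch the samples directly during training.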
    data = data.remove_columns(["quote", "author", "tags"])
    return data


def train(model_id, rank, dtype, monitor_tensors, max_seq_length, batch_size, max_steps, path_config):
    init_accelerator()
    device_module = getattr(torch, device, torch.cuda)
    accelerator_memory_init = device_module.max_memory_allocated()
    accelerator_memory_log = []

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = max_seq_length
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token
    data = get_data(tokenizer)
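
    # Load the base model in the requested precision; the int8/int4 paths use bitsandbytes quantization and prepare
    # the model for k-bit training.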
    if dtype == "int4":
        quant_config = BitsAndBytesConfig(load_in_4bit=True)
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, quantization_config=quant_config)
        model = prepare_model_for_kbit_training(model)
    elif dtype == "int8":
        quant_config = BitsAndBytesConfig(load_in_8bit=True)
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, quantization_config=quant_config)
        model = prepare_model_for_kbit_training(model)
    elif dtype == "bfloat16":
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.bfloat16)
    elif dtype == "float16":
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.float16)
    elif dtype == "float32":
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
    else:
        raise ValueError(f"Invalid dtype: {dtype}")
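
    # --rank only switches between LoRA (rank > 0) and full fine-tuning (rank 0); the actual LoRA hyperparameters
    # are taken from the config found at path_config.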
    if rank > 0:
        if path_config is None:
            raise RuntimeError("LoRA rank > 0 requires a path to a LoRA config")
        if path_config.endswith(CONFIG_NAME):
            path_config = path_config.removesuffix(CONFIG_NAME)
        config = LoraConfig.from_pretrained(path_config)
        model = get_peft_model(model, config)
        model.print_trainable_parameters()
    else:
        print("Not using LoRA")

    model.config.use_cache = False
    storage = []

    def pack(x):
        storage.append(x)
        return len(storage) - 1

    def unpack(x):
        return storage[x]
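
    # With --monitor_tensors, saved_tensors_hooks intercepts every tensor that autograd saves for the backward pass
    # (activations as well as weights) and stashes it in `storage` for inspection after the step.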
    train_ctx = partial(torch.autograd.graph.saved_tensors_hooks, pack, unpack) if monitor_tensors else nullcontext

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    losses = []
    sample = 0
    tic_total = time.perf_counter()
    for i in range(0, max_steps):
        storage.clear()
        tic = time.perf_counter()
        try:
            batch = tokenizer.pad(data["train"][sample : sample + batch_size], return_tensors="pt").to(model.device)
            sample += batch_size
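
            # For causal language modeling, the labels are the inputs themselves; the model shifts them internally.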
            batch["labels"] = batch["input_ids"].clone()
            optimizer.zero_grad()

            with train_ctx():
                outputs = model(**batch)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
            losses.append(loss.item())
            accelerator_memory_log.append(device_module.memory_allocated() - accelerator_memory_init)
            device_module.empty_cache()
            gc.collect()
            toc = time.perf_counter()
            print(f"step {i:3d} loss {loss.item():.6f} time {toc - tic:.2f}s", file=sys.stderr)
        except KeyboardInterrupt:
            print("canceled training")
            break
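
        # When monitoring tensors, a single training step is enough: `storage` already holds everything that was
        # saved for the backward pass.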
        if monitor_tensors:
            break

    toc_total = time.perf_counter()

    accelerator_memory_final = device_module.max_memory_allocated()
    accelerator_memory_avg = int(sum(accelerator_memory_log) / len(accelerator_memory_log))
    print(f"{model.device.type} memory avg: {accelerator_memory_avg // 2**20}MB")
    print(f"{model.device.type} memory max: {(accelerator_memory_final - accelerator_memory_init) // 2**20}MB")
    print(f"total time: {toc_total - tic_total:.2f}s")

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        stat = os.stat(os.path.join(tmp_dir, SAFETENSORS_WEIGHTS_NAME))
        file_size = stat.st_size
        print(f"file size: {file_size / 2**20:.1f}MB")

    if monitor_tensors:
        dtype_counts = Counter(t.dtype for t in storage)
        shape_counts = Counter(t.shape for t in storage)
        param_shape_counts = Counter(p.shape for p in model.parameters())
        param_shape_counts_copy = dict(param_shape_counts).copy()
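
        # Heuristic: saved tensors whose shape matches a parameter shape (or its transpose) are assumed to be the
        # weights themselves; any surplus counts are treated as activations.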
        diff_shape_counts = {}
        for shape, count in shape_counts.items():
            if shape in param_shape_counts_copy:
                diff_count = count - param_shape_counts[shape]
                if diff_count > 0:
                    diff_shape_counts[shape] = diff_count
                    param_shape_counts_copy[shape] = max(0, param_shape_counts_copy[shape] - diff_count)
            elif shape[::-1] in param_shape_counts:
                diff_count = count - param_shape_counts[shape[::-1]]
                if diff_count > 0:
                    diff_shape_counts[shape] = diff_count
                    param_shape_counts_copy[shape[::-1]] = max(0, param_shape_counts_copy[shape[::-1]] - diff_count)
            else:
                diff_shape_counts[shape] = count
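
        # Total bytes of everything saved for backward; the surplus shapes are costed as activations and the
        # remainder is attributed to parameters.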
        total_size = sum(t.numel() * t.element_size() for t in storage)
        total_size_mb = f"{total_size // 2**20}MB"
        diff_size = 0
        for shape, count in diff_shape_counts.items():
            diff_size += count * torch.zeros(shape).numel() * dtype_to_bytes_linear[dtype]
        param_size = total_size - diff_size

        diff_size_mb = f"{diff_size // 2**20}MB"
        param_size_mb = f"{param_size // 2**20}MB"

        print(f"Dtype counts: {dtype_counts.most_common()}")
        print(f"Total size of tensors: {total_size_mb: >12}")
        print(f"Total size of activations: {diff_size_mb: >12}")
        print(f"Total size of parameters: {param_size_mb: >12}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model_id", type=str, help="Model name on Hugging Face Hub")
    parser.add_argument("--rank", type=int, default=8, help="Rank of LoRA, 0 => no LoRA, default 8")
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        help="Data type, one of float32, float16, bfloat16, int8, int4, default float32",
    )
    parser.add_argument(
        "--monitor_tensors",
        action="store_true",
        help="Monitor tensor sizes during training for a single training step, off by default",
    )
    parser.add_argument("--max_seq_length", type=int, default=128, help="Maximum sequence length, default 128")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size, default 1")
    parser.add_argument("--max_steps", type=int, default=50, help="Maximum number of training steps, default 50")
    parser.add_argument("--path_config", type=str, default=None, help="Path to LoRA config")
    args = parser.parse_args()
    train(
        model_id=args.model_id,
        rank=args.rank,
        dtype=args.dtype,
        monitor_tensors=args.monitor_tensors,
        max_seq_length=args.max_seq_length,
        batch_size=args.batch_size,
        max_steps=args.max_steps,
        path_config=args.path_config,
    )