Upload folder using huggingface_hub
- nanochat/__init__.py +0 -0
- nanochat/adamw.py +77 -0
- nanochat/checkpoint_manager.py +146 -0
- nanochat/common.py +136 -0
- nanochat/configurator.py +56 -0
- nanochat/core_eval.py +262 -0
- nanochat/dataloader.py +49 -0
- nanochat/dataset.py +128 -0
- nanochat/engine.py +343 -0
- nanochat/execution.py +350 -0
- nanochat/gpt.py +322 -0
- nanochat/logo.svg +8 -0
- nanochat/loss_eval.py +63 -0
- nanochat/muon.py +187 -0
- nanochat/report.py +404 -0
- nanochat/tokenizer.py +395 -0
- nanochat/ui.html +394 -0
nanochat/__init__.py
ADDED
File without changes
nanochat/adamw.py
ADDED
@@ -0,0 +1,77 @@
"""
Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
Not a general optimizer! But works for our specific use.
"""
import torch
import torch.distributed as dist
from torch import Tensor


class DistAdamW(torch.optim.Optimizer):
    """
    Distributed AdamW optimizer.
    In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
    """
    def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(param_groups, defaults)

    @torch.compile
    @torch.no_grad()
    def step(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        reduce_scatter_futures: list[torch.Future] = []
        all_reduce_futures: list[torch.Future] = []
        grad_slices = []
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly
            for base_i in range(len(params)):
                grad = params[base_i].grad
                rank_size = grad.shape[0] // world_size
                grad_slice = torch.empty_like(grad[:rank_size])
                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                grad_slices.append(grad_slice)

        idx = 0
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            eps = group['eps']
            wd = group['weight_decay']
            params = group['params']
            for base in range(len(params)):
                reduce_scatter_futures[idx].wait()
                p = params[base]
                rank_size = p.shape[0] // world_size
                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
                state = self.state[p]
                g_slice = grad_slices[idx]
                # State init
                if not state:
                    state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
                    state['exp_avg'] = torch.zeros_like(p_slice)
                    state['exp_avg_sq'] = torch.zeros_like(p_slice)
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                state['step'] += 1
                t = state['step']
                # weight decay
                if wd != 0:
                    eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
                    p_slice.mul_(1 - eff_weight_decay)
                # update running averages
                exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
                # bias corrections
                bias1 = 1 - beta1 ** t
                bias2 = 1 - beta2 ** t
                # compute step
                denom = exp_avg_sq.sqrt().add_(eps)
                step_size = lr * (torch.sqrt(bias2) / bias1)
                update = exp_avg.div(denom).mul_(step_size)
                p_slice.add_(other=update, alpha=-1.0)
                idx += 1
                all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
        torch.futures.collect_all(all_reduce_futures).wait()
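As a hedged sketch only (the group structure and hyperparameters below are illustrative, not taken from the training scripts): DistAdamW shards optimizer state along each parameter's leading dimension, so it assumes torch.distributed is already initialized and that dim 0 of every parameter is divisible by the world size.

    # Illustrative wiring under torchrun; assumes dist is already initialized
    # (e.g. via compute_init in nanochat/common.py).
    from nanochat.adamw import DistAdamW

    def make_optimizer(model, lr=1e-3, weight_decay=0.01):
        # one flat group for simplicity; real scripts may split matrices vs. embeddings
        param_groups = [{"params": [p for p in model.parameters() if p.requires_grad]}]
        return DistAdamW(param_groups, lr=lr, weight_decay=weight_decay)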
nanochat/checkpoint_manager.py
ADDED
@@ -0,0 +1,146 @@
"""
Utilities for saving and loading model/optim/state checkpoints.
"""
import os
import re
import glob
import json
import logging
import torch

from nanochat.common import get_base_dir
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import get_tokenizer
from nanochat.common import setup_default_logging

# Set up logging
setup_default_logging()
logger = logging.getLogger(__name__)
def log0(message):
    if int(os.environ.get('RANK', 0)) == 0:
        logger.info(message)

def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
    assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Save the model state (parameters)
    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
    torch.save(model_data, model_path)
    log0(f"Saved model file to: {model_path}")
    # Save the optimizer state (useful for SFT or any other fine-tuning)
    if optimizer_data is not None:
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
        torch.save(optimizer_data, optimizer_path)
        log0(f"Saved optimizer file to: {optimizer_path}")
    # Save the metadata dict as json
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    with open(meta_path, "w") as f:
        json.dump(meta_data, f, indent=2)
    log0(f"Saved metadata file to: {meta_path}")


def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
    # Load the model state
    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
    model_data = torch.load(model_path, map_location=device)
    # Load the optimizer state if requested
    optimizer_data = None
    if load_optimizer:
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
        optimizer_data = torch.load(optimizer_path, map_location=device)
    # Load the metadata
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    with open(meta_path, "r") as f:
        meta_data = json.load(f)
    return model_data, optimizer_data, meta_data


def build_model(checkpoint_dir, step, device, phase):
    """
    A bunch of repetitive code to build a model from a given checkpoint.
    Returns:
    - base model - uncompiled, not wrapped in DDP
    - tokenizer
    - meta data saved during base model training
    """
    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
    # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
    model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()}
    model_config_kwargs = meta_data["model_config"]
    log0(f"Building model with config: {model_config_kwargs}")
    model_config = GPTConfig(**model_config_kwargs)
    with torch.device("meta"):
        model = GPT(model_config)
    # Load the model state
    model.to_empty(device=device)
    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
    model.load_state_dict(model_data, strict=True, assign=True)
    # Put the model in the right training phase / mode
    if phase == "eval":
        model.eval()
    else:
        model.train()
    # Load the Tokenizer
    tokenizer = get_tokenizer()
    # Sanity check: compatibility between model and tokenizer
    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
    return model, tokenizer, meta_data


def find_largest_model(checkpoint_dir):
    # attempt to guess the model tag: take the biggest model available
    model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
    if not model_tags:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    # 1) normally all model tags are of the form d<number>, try that first:
    candidates = []
    for model_tag in model_tags:
        match = re.match(r"d(\d+)", model_tag)
        if match:
            model_depth = int(match.group(1))
            candidates.append((model_depth, model_tag))
    if candidates:
        candidates.sort(key=lambda x: x[0], reverse=True)
        return candidates[0][1]
    # 2) if that failed, take the most recently updated model:
    model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
    return model_tags[0]


def find_last_step(checkpoint_dir):
    # Look into checkpoint_dir and find model_<step>.pt with the highest step
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
    if not checkpoint_files:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
    return last_step

# -----------------------------------------------------------------------------
# convenience functions that take into account nanochat's directory structure

def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
    if model_tag is None:
        # guess the model tag by defaulting to the largest model
        model_tag = find_largest_model(checkpoints_dir)
        log0(f"No model tag provided, guessing model tag: {model_tag}")
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
    if step is None:
        # guess the step by defaulting to the last step
        step = find_last_step(checkpoint_dir)
    assert step is not None, f"No checkpoints found in {checkpoint_dir}"
    # build the model
    log0(f"Loading model from {checkpoint_dir} with step {step}")
    model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
    return model, tokenizer, meta_data

def load_model(source, *args, **kwargs):
    model_dir = {
        "base": "base_checkpoints",
        "mid": "mid_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }[source]
    base_dir = get_base_dir()
    checkpoints_dir = os.path.join(base_dir, model_dir)
    return load_model_from_dir(checkpoints_dir, *args, **kwargs)
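A rough usage sketch of the convenience loader above (the "sft" source tag is one of the keys in load_model; letting model_tag and step default makes it pick the largest model and its latest checkpoint):

    # Hedged example: load the latest, largest SFT checkpoint for evaluation.
    import torch
    from nanochat.checkpoint_manager import load_model

    device = torch.device("cuda")
    model, tokenizer, meta = load_model("sft", device, phase="eval")  # guesses model_tag and step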
nanochat/common.py
ADDED
@@ -0,0 +1,136 @@
"""
Common utilities for nanochat.
"""

import os
import re
import logging
import torch
import torch.distributed as dist

class ColoredFormatter(logging.Formatter):
    """Custom formatter that adds colors to log messages."""
    # ANSI color codes
    COLORS = {
        'DEBUG': '\033[36m',    # Cyan
        'INFO': '\033[32m',     # Green
        'WARNING': '\033[33m',  # Yellow
        'ERROR': '\033[31m',    # Red
        'CRITICAL': '\033[35m', # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'
    def format(self, record):
        # Add color to the level name
        levelname = record.levelname
        if levelname in self.COLORS:
            record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
        # Format the message
        message = super().format(record)
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers and percentages
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message

def setup_default_logging():
    handler = logging.StreamHandler()
    handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logging.basicConfig(
        level=logging.INFO,
        handlers=[handler]
    )

setup_default_logging()
logger = logging.getLogger(__name__)

def get_base_dir():
    # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
    if os.environ.get("NANOCHAT_BASE_DIR"):
        nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
    else:
        home_dir = os.path.expanduser("~")
        cache_dir = os.path.join(home_dir, ".cache")
        nanochat_dir = os.path.join(cache_dir, "nanochat")
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir

def print0(s="", **kwargs):
    ddp_rank = int(os.environ.get('RANK', 0))
    if ddp_rank == 0:
        print(s, **kwargs)

def print_banner():
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    banner = """
█████ █████
░░███ ░░███
████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███ ░░░███░
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░████████ ░░█████
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
    """
    print0(banner)

def is_ddp():
    # TODO is there a proper way
    return int(os.environ.get('RANK', -1)) != -1

def get_dist_info():
    if is_ddp():
        assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        return True, ddp_rank, ddp_local_rank, ddp_world_size
    else:
        return False, 0, 0, 1

def compute_init():
    """Basic initialization that we keep doing over and over, so make common."""

    # CUDA is currently required
    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"

    # Reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    # Precision
    torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls

    # Distributed setup: Distributed Data Parallel (DDP), optional
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp:
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device) # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device("cuda")

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")

    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device

def compute_cleanup():
    """Companion function to compute_init, to clean things up before script exit"""
    if is_ddp():
        dist.destroy_process_group()

class DummyWandb:
    """Useful if we wish to not use wandb but have all the same signatures"""
    def __init__(self):
        pass
    def log(self, *args, **kwargs):
        pass
    def finish(self):
        pass
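The compute_init/compute_cleanup pair frames every script's lifecycle; a minimal sketch of that pattern (the body of the loop is placeholder):

    # Hedged sketch of the init/cleanup lifecycle a nanochat script would follow.
    from nanochat.common import compute_init, compute_cleanup, print0

    ddp, rank, local_rank, world_size, device = compute_init()
    print0(f"running on {world_size} process(es), device {device}")
    # ... build model, optimizer, data loader; run the training loop ...
    compute_cleanup()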
nanochat/configurator.py
ADDED
@@ -0,0 +1,56 @@
"""
Poor Man's Configurator. Probably a terrible idea. Example usage:
$ python train.py config/override_file.py --batch_size=32
this will first run config/override_file.py, then override batch_size to 32

The code in this file will be run as follows from e.g. train.py:
>>> exec(open('configurator.py').read())

So it's not a Python module, it's just shuttling this code away from train.py
The code in this script then overrides the globals()

I know people are not going to love this, I just really dislike configuration
complexity and having to prepend config. to every single variable. If someone
comes up with a better simple Python solution I am all ears.
"""

import os
import sys
from ast import literal_eval

def print0(s="", **kwargs):
    ddp_rank = int(os.environ.get('RANK', 0))
    if ddp_rank == 0:
        print(s, **kwargs)

for arg in sys.argv[1:]:
    if '=' not in arg:
        # assume it's the name of a config file
        assert not arg.startswith('--')
        config_file = arg
        print0(f"Overriding config with {config_file}:")
        with open(config_file) as f:
            print0(f.read())
        exec(open(config_file).read())
    else:
        # assume it's a --key=value argument
        assert arg.startswith('--')
        key, val = arg.split('=')
        key = key[2:]
        if key in globals():
            try:
                # attempt to eval it (e.g. if bool, number, or etc)
                attempt = literal_eval(val)
            except (SyntaxError, ValueError):
                # if that goes wrong, just use the string
                attempt = val
            # ensure the types match ok
            if globals()[key] is not None:
                attempt_type = type(attempt)
                default_type = type(globals()[key])
                assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
            # cross fingers
            print0(f"Overriding: {key} = {attempt}")
            globals()[key] = attempt
        else:
            raise ValueError(f"Unknown config key: {key}")
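A minimal sketch of the consuming side, following the pattern in the docstring above (the variable names and the exec path are illustrative; the path depends on where the calling script sits relative to configurator.py):

    # hypothetical train.py: define defaults, then let configurator.py rewrite them
    batch_size = 64
    learning_rate = 3e-4
    exec(open('nanochat/configurator.py').read())  # may override the globals above from CLI args
    print(f"training with batch_size={batch_size}, learning_rate={learning_rate}")

Invoked as, e.g., python train.py --batch_size=32, the loop above literal_eval's "32", checks its type against the default, and overwrites the global.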
nanochat/core_eval.py
ADDED
@@ -0,0 +1,262 @@
"""
Functions for evaluating the CORE metric, as described in the DCLM paper.
https://arxiv.org/abs/2406.11794

TODOs:
- All tasks ~match except for squad. We get 31%, reference is 37%. Figure out why.
"""
import random

from jinja2 import Template
import torch
import torch.distributed as dist

# -----------------------------------------------------------------------------
# Prompt rendering utilities

def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
    """Render complete prompts for a multiple choice question"""
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}

{% endfor -%}
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
    template = Template(template_str)
    fewshot_examples = fewshot_examples or []
    context = {
        'fewshot_examples': fewshot_examples,
        'continuation_delimiter': continuation_delimiter,
        'item': item
    }
    prompts = [template.render(choice=choice, **context) for choice in item['choices']]
    return prompts


def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
    """Render complete prompts for a schema question"""
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}

{% endfor -%}
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
    template = Template(template_str)
    fewshot_examples = fewshot_examples or []
    context = {
        'fewshot_examples': fewshot_examples,
        'continuation_delimiter': continuation_delimiter,
        'item': item
    }
    prompts = [template.render(context=context_option, **context)
               for context_option in item['context_options']]
    return prompts


def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
    """
    Render complete prompt for a language modeling task.
    Notice that we manually trim the context in the template,
    which in some datasets seems to have trailing whitespace (which we don't want).
    """
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}

{% endfor -%}
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
    template = Template(template_str)
    fewshot_examples = fewshot_examples or []
    context = {
        'fewshot_examples': fewshot_examples,
        'continuation_delimiter': continuation_delimiter,
        'item': item
    }
    # Return two prompts: without and with the continuation
    prompt_without = template.render(include_continuation=False, **context)
    prompt_with = template.render(include_continuation=True, **context)
    # Due to the way the data seems to be stored, I think I need to strip in the case of LM here.
    # Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next
    # token in prompt_with), meaning we don't get a nice and clean prefix in the token space
    # to detect the final continuation. Tokenizers...
    prompt_without = prompt_without.strip()
    return [prompt_without, prompt_with]


def find_common_length(token_sequences, direction='left'):
    """
    Find the length of the common prefix or suffix across token sequences
    - direction: 'left' for prefix, 'right' for suffix
    """
    min_len = min(len(seq) for seq in token_sequences)
    indices = {
        'left': range(min_len),
        'right': range(-1, -min_len-1, -1)
    }[direction]
    # Find the first position where the token sequences differ
    for i, idx in enumerate(indices):
        token = token_sequences[0][idx]
        if not all(seq[idx] == token for seq in token_sequences):
            return i
    return min_len


def stack_sequences(tokens, pad_token_id):
    """Stack up a list of token sequences, pad to longest on the right"""
    bsz, seq_len = len(tokens), max(len(x) for x in tokens)
    input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
    for i, x in enumerate(tokens):
        input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
    return input_ids


def batch_sequences_mc(tokenizer, prompts):
    # In multiple choice, contexts are the same but the continuation is different (common prefix)
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    # figure out the start and end of each continuation
    answer_start_idx = find_common_length(tokens, direction='left')
    start_indices = [answer_start_idx] * len(prompts)
    end_indices = [len(x) for x in tokens]
    return tokens, start_indices, end_indices


def batch_sequences_schema(tokenizer, prompts):
    # In schema tasks, contexts vary but continuation is the same (common suffix)
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    # figure out the start and end of each context
    suffix_length = find_common_length(tokens, direction='right')
    end_indices = [len(x) for x in tokens]
    start_indices = [ei - suffix_length for ei in end_indices]
    return tokens, start_indices, end_indices


def batch_sequences_lm(tokenizer, prompts):
    # In LM tasks, we have two prompts: without and with continuation
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    tokens_without, tokens_with = tokens
    start_idx, end_idx = len(tokens_without), len(tokens_with)
    assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
    assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
    # we only need the with continuation prompt in the LM task, i.e. batch size of 1
    return [tokens_with], [start_idx], [end_idx]


@torch.no_grad()
def forward_model(model, input_ids):
    """
    Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
    The last column of losses is set to nan because we don't have autoregressive targets there.
    """
    batch_size, seq_len = input_ids.size()
    outputs = model(input_ids)
    # Roll the tensor to the left by one position to get the (autoregressive) target ids
    target_ids = torch.roll(input_ids, shifts=-1, dims=1)
    # Calculate cross entropy at all positions
    losses = torch.nn.functional.cross_entropy(
        outputs.view(batch_size * seq_len, -1),
        target_ids.view(batch_size * seq_len),
        reduction='none'
    ).view(batch_size, seq_len)
    # Set the last column to be nan because there is no autoregressive loss there
    losses[:, -1] = float('nan')
    # Get the argmax predictions at each position
    predictions = outputs.argmax(dim=-1)
    return losses, predictions


@torch.no_grad()
def evaluate_example(idx, model, tokenizer, data, device, task_meta):
    """Evaluate a single example, return True if correct, False otherwise"""
    item = data[idx]
    task_type = task_meta['task_type']
    num_fewshot = task_meta['num_fewshot']
    continuation_delimiter = task_meta['continuation_delimiter']

    # Sample few-shot examples (excluding current item)
    fewshot_examples = []
    if num_fewshot > 0:
        rng = random.Random(1234 + idx)
        available_indices = [i for i in range(len(data)) if i != idx]
        fewshot_indices = rng.sample(available_indices, num_fewshot)
        fewshot_examples = [data[i] for i in fewshot_indices]

    # Render prompts and batch sequences based on task type
    if task_type == 'multiple_choice':
        prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
    elif task_type == 'schema':
        prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
    elif task_type == 'language_modeling':
        prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
    else:
        raise ValueError(f"Unsupported task type: {task_type}")

    # Some models can't forward sequences beyond a certain length (e.g. GPT-2)
    # In these cases, we have to truncate sequences to max length and adjust the indices
    if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
        max_tokens = model.max_seq_len
        new_tokens, new_start_idxs, new_end_idxs = [], [], []
        for t, s, e in zip(tokens, start_idxs, end_idxs):
            if len(t) > max_tokens:
                num_to_crop = len(t) - max_tokens
                new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
                new_start_idxs.append(s - num_to_crop) # shift the indices down
                new_end_idxs.append(e - num_to_crop)
                assert s - num_to_crop >= 0, "this should never happen right?"
                assert e - num_to_crop >= 0, "this should never happen right?"
            else:
                new_tokens.append(t) # keep unchanged
                new_start_idxs.append(s)
                new_end_idxs.append(e)
        tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs

    # Stack up all the sequences into a batch
    pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
    input_ids = stack_sequences(tokens, pad_token_id)
    input_ids = input_ids.to(device)

    # Forward the model, get the autoregressive loss and argmax prediction at each token
    losses, predictions = forward_model(model, input_ids)

    # See if the losses/predictions come out correctly
    if task_type == 'language_modeling':
        # language modeling task is currently always batch size 1
        si = start_idxs[0]
        ei = end_idxs[0]
        # predictions[i] predict input_ids[i+1] autoregressively
        predicted_tokens = predictions[0, si-1:ei-1]
        actual_tokens = input_ids[0, si:ei]
        is_correct = torch.all(predicted_tokens == actual_tokens).item()
    elif task_type in ['multiple_choice', 'schema']:
        # For MC/schema: find the option with lowest average loss
        mean_losses = [losses[i, si-1:ei-1].mean().item()
                       for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
        pred_idx = mean_losses.index(min(mean_losses))
        is_correct = pred_idx == item['gold']
    else:
        raise ValueError(f"Unsupported task type: {task_type}")

    return is_correct


def evaluate_task(model, tokenizer, data, device, task_meta):
    """
    This function is responsible for evaluating one task across many examples.
    It also handles dispatch to all processes if the script is run with torchrun.
    """
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    correct = torch.zeros(len(data), dtype=torch.float32, device=device)
    # stride the examples to each rank
    for idx in range(rank, len(data), world_size):
        is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
        correct[idx] = float(is_correct)
    # sync results across all the processes if running distributed
    if world_size > 1:
        dist.barrier()
        dist.all_reduce(correct, op=dist.ReduceOp.SUM)
    # compute the mean
    mean_correct = correct.mean().item()
    return mean_correct
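A hedged sketch of how evaluate_task would be called; the toy data item below only illustrates the fields evaluate_example reads for a multiple-choice task ('query', 'choices', 'gold'), and model/tokenizer/device are assumed to come from checkpoint_manager.load_model:

    # Illustrative only: zero-shot multiple choice on a tiny hand-made "task".
    task_meta = {
        "task_type": "multiple_choice",
        "num_fewshot": 0,
        "continuation_delimiter": " ",
    }
    data = [
        {"query": "2 + 2 =", "choices": ["3", "4", "5"], "gold": 1},
        {"query": "The capital of France is", "choices": ["Paris", "Rome"], "gold": 0},
    ]
    accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
    print(f"accuracy: {accuracy:.3f}")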
nanochat/dataloader.py
ADDED
@@ -0,0 +1,49 @@
from collections import deque

import torch

from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import get_tokenizer

def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
    """Stream pretraining text from parquet files, tokenize, yield training batches."""
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
    # get the tokenizer and the bos token
    tokenizer = get_tokenizer()
    bos_token = tokenizer.get_bos_token_id()
    # scratch buffer holds the tokens for one iteration
    token_buffer = deque() # we stream tokens on the right and pop from the left
    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)

    # infinite iterator over document batches
    def document_batches():
        while True:
            # batch will iterate in group size of the parquet files, usually e.g. 1024 rows
            for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
                # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i+tokenizer_batch_size]
    batches = document_batches()

    batch_index = 0
    while True:
        # Accumulate enough tokens for one iteration before yielding.
        while len(token_buffer) < needed_tokens:
            doc_batch = next(batches)
            token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
            for tokens in token_lists:
                token_buffer.extend(tokens)
            batch_index += 1
        # Move tokens from the deque into the scratch buffer
        for i in range(needed_tokens):
            scratch[i] = token_buffer.popleft()
        # Create the inputs/targets as 1D tensors
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
        # Reshape to 2D and move to GPU async
        inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
        yield inputs, targets
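Consuming the infinite generator is a matter of calling next() once per training step; a minimal sketch (assumes the tokenizer is built and at least one data shard is on disk, and the B/T values are illustrative):

    # Hedged sketch: drive the streaming loader for a few steps.
    from nanochat.dataloader import tokenizing_distributed_data_loader

    loader = tokenizing_distributed_data_loader(B=32, T=2048, split="train")
    for step in range(10):
        inputs, targets = next(loader)  # inputs: (B, T) int32 on cuda, targets: (B, T) int64
        # loss = model(inputs, targets); loss.backward(); optimizer.step(); ...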
nanochat/dataset.py
ADDED
@@ -0,0 +1,128 @@
"""
The base/pretraining dataset is a set of parquet files.
This file contains utilities for:
- iterating over the parquet files and yielding documents from it
- download the files on demand if they are not on disk

For details of how the dataset was prepared, see `repackage_data_reference.py`.
"""

import os
import argparse
import time
import requests
import pyarrow.parquet as pq
from multiprocessing import Pool

from nanochat.common import get_base_dir

# -----------------------------------------------------------------------------
# The specifics of the current pretraining dataset

# The URL on the internet where the data is hosted and downloaded from on demand
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
base_dir = get_base_dir()
DATA_DIR = os.path.join(base_dir, "base_data")
os.makedirs(DATA_DIR, exist_ok=True)

# -----------------------------------------------------------------------------
# These functions are useful utilities to other modules, can/should be imported

def list_parquet_files(data_dir=None):
    """ Looks into a data dir and returns full paths to all parquet files. """
    data_dir = DATA_DIR if data_dir is None else data_dir
    parquet_files = sorted([
        f for f in os.listdir(data_dir)
        if f.endswith('.parquet') and not f.endswith('.tmp')
    ])
    parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
    return parquet_paths

def parquets_iter_batched(split, start=0, step=1):
    """
    Iterate through the dataset, in batches of underlying row_groups for efficiency.
    - split can be "train" or "val". the last parquet file will be val.
    - start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    parquet_paths = list_parquet_files()
    parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
    for filepath in parquet_paths:
        pf = pq.ParquetFile(filepath)
        for rg_idx in range(start, pf.num_row_groups, step):
            rg = pf.read_row_group(rg_idx)
            texts = rg.column('text').to_pylist()
            yield texts

# -----------------------------------------------------------------------------
def download_single_file(index):
    """ Downloads a single file index, with some backoff """

    # Construct the local filepath for this file and skip if it already exists
    filename = index_to_filename(index)
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        print(f"Skipping {filepath} (already exists)")
        return True

    # Construct the remote URL for this file
    url = f"{BASE_URL}/{filename}"
    print(f"Downloading {filename}...")

    # Download with retries
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            # Write to temporary file first
            temp_path = filepath + f".tmp"
            with open(temp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
                    if chunk:
                        f.write(chunk)
            # Move temp file to final location
            os.rename(temp_path, filepath)
            print(f"Successfully downloaded {filename}")
            return True

        except (requests.RequestException, IOError) as e:
            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            # Clean up any partial files
            for path in [filepath + f".tmp", filepath]:
                if os.path.exists(path):
                    try:
                        os.remove(path)
                    except:
                        pass
            # Try a few times with exponential backoff: 2^attempt seconds
            if attempt < max_attempts:
                wait_time = 2 ** attempt
                print(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
            else:
                print(f"Failed to download {filename} after {max_attempts} attempts")
                return False

    return False


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
    parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
    parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
    args = parser.parse_args()

    num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
    ids_to_download = list(range(num))
    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
    print(f"Target directory: {DATA_DIR}")
    print()
    with Pool(processes=args.num_workers) as pool:
        results = pool.map(download_single_file, ids_to_download)

    # Report results
    successful = sum(1 for success in results if success)
    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
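Besides the CLI entry point above, the functions can be used programmatically; a small hedged sketch for smoke-testing with just a couple of shards (shard count here is illustrative):

    # Fetch two shards, then read the first row group of documents from the train split.
    from nanochat.dataset import download_single_file, parquets_iter_batched

    for i in range(2):
        download_single_file(i)  # no-op if the shard is already on disk
    texts = next(parquets_iter_batched(split="train"))  # list of document strings
    print(len(texts), texts[0][:80])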
nanochat/engine.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Engine for efficient inference of our models.
|
| 3 |
+
|
| 4 |
+
Everything works around token sequences:
|
| 5 |
+
- The user can send token sequences to the engine
|
| 6 |
+
- The engine returns the next token
|
| 7 |
+
|
| 8 |
+
Notes:
|
| 9 |
+
- The engine knows nothing about tokenization, it's purely token id sequences.
|
| 10 |
+
|
| 11 |
+
The whole thing is made as efficient as possible.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import signal
|
| 17 |
+
import warnings
|
| 18 |
+
from contextlib import contextmanager
|
| 19 |
+
from collections import deque
|
| 20 |
+
from nanochat.common import compute_init
|
| 21 |
+
from nanochat.checkpoint_manager import load_model
|
| 22 |
+
|
| 23 |
+
# -----------------------------------------------------------------------------
|
| 24 |
+
# Calculator tool helpers
|
| 25 |
+
@contextmanager
|
| 26 |
+
def timeout(duration, formula):
|
| 27 |
+
def timeout_handler(signum, frame):
|
| 28 |
+
raise Exception(f"'{formula}': timed out after {duration} seconds")
|
| 29 |
+
|
| 30 |
+
signal.signal(signal.SIGALRM, timeout_handler)
|
| 31 |
+
signal.alarm(duration)
|
| 32 |
+
yield
|
| 33 |
+
signal.alarm(0)
|
| 34 |
+
|
| 35 |
+
def eval_with_timeout(formula, max_time=3):
|
| 36 |
+
try:
|
| 37 |
+
with timeout(max_time, formula):
|
| 38 |
+
with warnings.catch_warnings():
|
| 39 |
+
warnings.simplefilter("ignore", SyntaxWarning)
|
| 40 |
+
return eval(formula)
|
| 41 |
+
except Exception as e:
|
| 42 |
+
signal.alarm(0)
|
| 43 |
+
# print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
def use_calculator(expr):
|
| 47 |
+
"""Evaluate a math expression safely."""
|
| 48 |
+
expr = expr.replace(",", "")
|
| 49 |
+
if any([x not in "0123456789*+-/.() " for x in expr]): # for now disallow non-numeric chars
|
| 50 |
+
return None
|
| 51 |
+
if "**" in expr: # for now disallow power operator, could be very expensive
|
| 52 |
+
return None
|
| 53 |
+
return eval_with_timeout(expr)
|
| 54 |
+
|
| 55 |
+
# -----------------------------------------------------------------------------
|
| 56 |
+
class KVCache:
|
| 57 |
+
"""
|
| 58 |
+
Works hand-in-hand with the GPT model to maintain the KV cache.
|
| 59 |
+
Note that the .pos advances automatically after the last layer of the Transformer inserts.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
|
| 63 |
+
# Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
|
| 64 |
+
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
|
| 65 |
+
self.kv_cache = None
|
| 66 |
+
self.pos = 0 # current position in time in the cache
|
| 67 |
+
|
| 68 |
+
def reset(self):
|
| 69 |
+
self.pos = 0
|
| 70 |
+
|
| 71 |
+
def get_pos(self):
|
| 72 |
+
return self.pos
|
| 73 |
+
|
| 74 |
+
def prefill(self, other):
|
| 75 |
+
"""
|
| 76 |
+
Prefill given another KV cache. Optionally expand along batch dim.
|
| 77 |
+
This is used when we do batch 1 prefill and then want to generate
|
| 78 |
+
multiple samples in parallel from there.
|
| 79 |
+
"""
|
| 80 |
+
# 1) validate the shapes
|
| 81 |
+
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
|
| 82 |
+
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
|
| 83 |
+
for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
|
| 84 |
+
if ix in [0, 1, 3, 5]:
|
| 85 |
+
# num_layers, batch_size, num_heads, head_dim must match
|
| 86 |
+
assert dim1 == dim2, f"Batch dim mismatch: {dim1} != {dim2}"
|
| 87 |
+
elif ix == 2:
|
| 88 |
+
# batch_size can be expanded
|
| 89 |
+
assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
|
| 90 |
+
elif ix == 4:
|
| 91 |
+
# seq_len: self must be longer than other
|
| 92 |
+
assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
|
| 93 |
+
# 2) initialize the cache
|
| 94 |
+
dtype, device = other.kv_cache.dtype, other.kv_cache.device
|
| 95 |
+
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
|
| 96 |
+
# 3) copy the data over
|
| 97 |
+
self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
|
| 98 |
+
# 4) update the pos
|
| 99 |
+
self.pos = other.pos
|
| 100 |
+
|
| 101 |
+
def insert_kv(self, layer_idx, k, v):
|
| 102 |
+
# Lazy initialize the cache here because we need to know the dtype/device
|
| 103 |
+
if self.kv_cache is None:
|
| 104 |
+
self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
|
| 105 |
+
# Insert new keys/values to the cache and return the full cache so far
|
| 106 |
+
B, H, T_add, D = k.size()
|
| 107 |
+
t0, t1 = self.pos, self.pos + T_add
|
| 108 |
+
# Dynamically grow the cache if needed
|
| 109 |
+
if t1 > self.kv_cache.size(4):
|
| 110 |
+
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
|
| 111 |
+
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
|
| 112 |
+
current_shape = list(self.kv_cache.shape)
|
| 113 |
+
current_shape[4] = t_needed
|
| 114 |
+
self.kv_cache.resize_(current_shape)
|
| 115 |
+
# Insert k, v into the cache
|
| 116 |
+
self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
|
| 117 |
+
self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
|
| 118 |
+
# Return the full cached keys/values up to current position (as a view)
|
| 119 |
+
key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
|
| 120 |
+
value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
|
| 121 |
+
# Increment pos after the last layer of the Transformer processes
|
| 122 |
+
if layer_idx == self.kv_cache.size(0) - 1:
|
| 123 |
+
self.pos = t1
|
| 124 |
+
return key_view, value_view
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# -----------------------------------------------------------------------------
|
| 128 |
+
@torch.inference_mode()
|
| 129 |
+
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
|
| 130 |
+
"""Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
|
| 131 |
+
assert temperature >= 0.0, "temperature must be non-negative"
|
| 132 |
+
if temperature == 0.0:
|
| 133 |
+
return torch.argmax(logits, dim=-1, keepdim=True)
|
| 134 |
+
if top_k is not None:
|
| 135 |
+
k = min(top_k, logits.size(-1))
|
| 136 |
+
vals, idx = torch.topk(logits, k, dim=-1)
|
| 137 |
+
vals = vals / temperature
|
| 138 |
+
probs = F.softmax(vals, dim=-1)
|
| 139 |
+
choice = torch.multinomial(probs, num_samples=1, generator=rng)
|
| 140 |
+
return idx.gather(1, choice)
|
| 141 |
+
else:
|
| 142 |
+
logits = logits / temperature
|
| 143 |
+
probs = F.softmax(logits, dim=-1)
|
| 144 |
+
return torch.multinomial(probs, num_samples=1, generator=rng)
|
| 145 |
+
|
| 146 |
+
# -----------------------------------------------------------------------------
|
| 147 |
+
|
| 148 |
+
class RowState:
|
| 149 |
+
# Per-row state tracking during generation
|
| 150 |
+
def __init__(self, current_tokens=None):
|
| 151 |
+
self.current_tokens = current_tokens or [] # Current token sequence for this row
|
| 152 |
+
self.forced_tokens = deque() # Queue of tokens to force inject
|
| 153 |
+
self.in_python_block = False # Whether we are inside a python block
|
| 154 |
+
self.python_expr_tokens = [] # Tokens of the current python expression
|
| 155 |
+
self.completed = False # Whether this row has completed generation
|
| 156 |
+
|
| 157 |
+
class Engine:
|
| 158 |
+
|
| 159 |
+
def __init__(self, model, tokenizer):
|
| 160 |
+
self.model = model
|
| 161 |
+
self.tokenizer = tokenizer # needed for tool use
|
| 162 |
+
|
| 163 |
+
@torch.inference_mode()
|
| 164 |
+
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
|
| 165 |
+
"""Same as generate, but does single prefill and then clones the KV cache."""
|
| 166 |
+
assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
|
| 167 |
+
device = self.model.get_device()
|
| 168 |
+
rng = torch.Generator(device=device)
|
| 169 |
+
rng.manual_seed(seed)
|
| 170 |
+
|
| 171 |
+
# Get the special tokens we need to coordinate the tool use state machine
|
| 172 |
+
get_special = lambda s: self.tokenizer.encode_special(s)
|
| 173 |
+
python_start = get_special("<|python_start|>")
|
| 174 |
+
python_end = get_special("<|python_end|>")
|
| 175 |
+
output_start = get_special("<|output_start|>")
|
| 176 |
+
output_end = get_special("<|output_end|>")
|
| 177 |
+
assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
|
| 178 |
+
bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
|
| 179 |
+
|
| 180 |
+
# 1) Run a batch 1 prefill of the prompt tokens
|
| 181 |
+
m = self.model.config
|
| 182 |
+
kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
|
| 183 |
+
kv_cache_prefill = KVCache(
|
| 184 |
+
batch_size=1,
|
| 185 |
+
seq_len=len(tokens),
|
| 186 |
+
**kv_model_kwargs,
|
| 187 |
+
)
|
| 188 |
+
ids = torch.tensor([tokens], dtype=torch.long, device=device)
|
| 189 |
+
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
|
| 190 |
+
logits = logits[:, -1, :]
|
| 191 |
+
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
|
| 192 |
+
sampled_tokens = next_ids[:, 0].tolist()
|
| 193 |
+
|
| 194 |
+
# 2) Replicate the KV cache for each sample/row
|
| 195 |
+
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
|
| 196 |
+
kv_cache_decode = KVCache(
|
| 197 |
+
batch_size=num_samples,
|
| 198 |
+
seq_len=kv_length_hint,
|
| 199 |
+
**kv_model_kwargs,
|
| 200 |
+
)
|
| 201 |
+
kv_cache_decode.prefill(kv_cache_prefill)
|
| 202 |
+
del kv_cache_prefill # no need to keep this memory around
|
| 203 |
+
|
| 204 |
+
# 3) Initialize states for each sample
|
| 205 |
+
row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
|
| 206 |
+
|
| 207 |
+
# 4) Main generation loop
|
| 208 |
+
num_generated = 0
|
| 209 |
+
first_iteration = True
|
| 210 |
+
while True:
|
| 211 |
+
# Stop condition: we've reached max tokens
|
| 212 |
+
if max_tokens is not None and num_generated >= max_tokens:
|
| 213 |
+
break
|
| 214 |
+
# Stop condition: all rows are completed
|
| 215 |
+
if all(state.completed for state in row_states):
|
| 216 |
+
break
|
| 217 |
+
|
| 218 |
+
# Get sampled tokens - either from prefill or from forward pass
|
| 219 |
+
if first_iteration:
|
| 220 |
+
# Use the tokens we already sampled from prefill
|
| 221 |
+
sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
|
| 222 |
+
# TODO: we should sample a token for each row instead of broadcasting
|
| 223 |
+
first_iteration = False
|
| 224 |
+
else:
|
| 225 |
+
# Forward the model and get the next token for each row
|
| 226 |
+
logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
|
| 227 |
+
logits = logits[:, -1, :] # (B, vocab_size) at last time step
|
| 228 |
+
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
|
| 229 |
+
sampled_tokens = next_ids[:, 0].tolist()
|
| 230 |
+
|
| 231 |
+
# Process each row: choose the next token, update state, optional tool use
|
| 232 |
+
token_column = [] # contains the next token id along each row
|
| 233 |
+
token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
|
| 234 |
+
for i, state in enumerate(row_states):
|
| 235 |
+
# Select the next token in this row
|
| 236 |
+
is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
|
| 237 |
+
token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
|
| 238 |
+
next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
|
| 239 |
+
token_column.append(next_token)
|
| 240 |
+
# Update the state of this row to include the next token
|
| 241 |
+
state.current_tokens.append(next_token)
|
| 242 |
+
# On <|assistant_end|> or <|bos|>, mark the row as completed
|
| 243 |
+
if next_token == assistant_end or next_token == bos:
|
| 244 |
+
state.completed = True
|
| 245 |
+
# Handle tool logic
|
| 246 |
+
if next_token == python_start:
|
| 247 |
+
state.in_python_block = True
|
| 248 |
+
state.python_expr_tokens = []
|
| 249 |
+
elif next_token == python_end and state.in_python_block:
|
| 250 |
+
state.in_python_block = False
|
| 251 |
+
if state.python_expr_tokens:
|
| 252 |
+
expr = self.tokenizer.decode(state.python_expr_tokens)
|
| 253 |
+
result = use_calculator(expr)
|
| 254 |
+
if result is not None:
|
| 255 |
+
result_tokens = self.tokenizer.encode(str(result))
|
| 256 |
+
state.forced_tokens.append(output_start)
|
| 257 |
+
state.forced_tokens.extend(result_tokens)
|
| 258 |
+
state.forced_tokens.append(output_end)
|
| 259 |
+
state.python_expr_tokens = []
|
| 260 |
+
elif state.in_python_block:
|
| 261 |
+
state.python_expr_tokens.append(next_token)
|
| 262 |
+
|
| 263 |
+
# Yield the token column
|
| 264 |
+
yield token_column, token_masks
|
| 265 |
+
num_generated += 1
|
| 266 |
+
# Prepare ids for next iteration
|
| 267 |
+
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
|
| 268 |
+
|
| 269 |
+
def generate_batch(self, tokens, num_samples=1, **kwargs):
|
| 270 |
+
"""
|
| 271 |
+
Non-streaming batch generation that just returns the final token sequences.
|
| 272 |
+
Returns a list of token sequences (list of lists of ints).
|
| 273 |
+
Terminal tokens (assistant_end, bos) are not included in the results.
|
| 274 |
+
"""
|
| 275 |
+
assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
|
| 276 |
+
bos = self.tokenizer.get_bos_token_id()
|
| 277 |
+
results = [tokens.copy() for _ in range(num_samples)]
|
| 278 |
+
masks = [[0] * len(tokens) for _ in range(num_samples)]
|
| 279 |
+
completed = [False] * num_samples
|
| 280 |
+
for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
|
| 281 |
+
for i, (token, mask) in enumerate(zip(token_column, token_masks)):
|
| 282 |
+
if not completed[i]:
|
| 283 |
+
if token == assistant_end or token == bos:
|
| 284 |
+
completed[i] = True
|
| 285 |
+
else:
|
| 286 |
+
results[i].append(token)
|
| 287 |
+
masks[i].append(mask)
|
| 288 |
+
# Stop if all rows are completed
|
| 289 |
+
if all(completed):
|
| 290 |
+
break
|
| 291 |
+
return results, masks
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
if __name__ == "__main__":
|
| 295 |
+
"""
|
| 296 |
+
Quick inline test to make sure that the naive/slow model.generate function
|
| 297 |
+
is equivalent to the faster Engine.generate function here.
|
| 298 |
+
"""
|
| 299 |
+
import time
|
| 300 |
+
# init compute
|
| 301 |
+
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
| 302 |
+
# load the model and tokenizer
|
| 303 |
+
model, tokenizer, meta = load_model("base", device, phase="eval")
|
| 304 |
+
bos_token_id = tokenizer.get_bos_token_id()
|
| 305 |
+
# common hyperparameters
|
| 306 |
+
kwargs = dict(max_tokens=64, temperature=0.0)
|
| 307 |
+
# set the starting prompt
|
| 308 |
+
prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
|
| 309 |
+
# generate the reference sequence using the model.generate() function
|
| 310 |
+
generated_tokens = []
|
| 311 |
+
torch.cuda.synchronize()
|
| 312 |
+
t0 = time.time()
|
| 313 |
+
stream = model.generate(prompt_tokens, **kwargs)
|
| 314 |
+
for token in stream:
|
| 315 |
+
generated_tokens.append(token)
|
| 316 |
+
chunk = tokenizer.decode([token])
|
| 317 |
+
print(chunk, end="", flush=True)
|
| 318 |
+
print()
|
| 319 |
+
torch.cuda.synchronize()
|
| 320 |
+
t1 = time.time()
|
| 321 |
+
print(f"Reference time: {t1 - t0:.2f}s")
|
| 322 |
+
reference_ids = generated_tokens
|
| 323 |
+
# generate tokens with Engine
|
| 324 |
+
generated_tokens = []
|
| 325 |
+
engine = Engine(model, tokenizer)
|
| 326 |
+
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
|
| 327 |
+
torch.cuda.synchronize()
|
| 328 |
+
t0 = time.time()
|
| 329 |
+
for token_column, token_masks in stream:
|
| 330 |
+
token = token_column[0] # only print out the first row
|
| 331 |
+
generated_tokens.append(token)
|
| 332 |
+
chunk = tokenizer.decode([token])
|
| 333 |
+
print(chunk, end="", flush=True)
|
| 334 |
+
print()
|
| 335 |
+
torch.cuda.synchronize()
|
| 336 |
+
t1 = time.time()
|
| 337 |
+
print(f"Engine time: {t1 - t0:.2f}s")
|
| 338 |
+
# compare the two sequences
|
| 339 |
+
for i in range(len(reference_ids)):
|
| 340 |
+
if reference_ids[i] != generated_tokens[i]:
|
| 341 |
+
print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
|
| 342 |
+
break
|
| 343 |
+
print(f"Match: {reference_ids == generated_tokens}")
|
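A brief usage sketch for the Engine above (editor's addition, not part of the diff). The compute_init and load_model import paths are assumptions based on how the __main__ block uses them; adjust to the actual module layout.

from nanochat.common import compute_init              # assumed location
from nanochat.checkpoint_manager import load_model    # assumed location
from nanochat.engine import Engine

ddp, rank, local_rank, world_size, device = compute_init()
model, tokenizer, meta = load_model("base", device, phase="eval")
engine = Engine(model, tokenizer)

prompt = tokenizer.encode("The chemical formula of water is",
                          prepend=tokenizer.get_bos_token_id())
results, masks = engine.generate_batch(prompt, num_samples=2,
                                       max_tokens=32, temperature=0.8, top_k=50)
for row in results:
    print(tokenizer.decode(row))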
nanochat/execution.py
ADDED
|
@@ -0,0 +1,350 @@
|
| 1 |
+
"""
|
| 2 |
+
Sandboxed execution utilities for running Python code that comes out of an LLM.
|
| 3 |
+
Adapted from OpenAI HumanEval code:
|
| 4 |
+
https://github.com/openai/human-eval/blob/master/human_eval/execution.py
|
| 5 |
+
|
| 6 |
+
What is covered:
|
| 7 |
+
- Each execution runs in its own process (can be killed if it hangs or crashes)
|
| 8 |
+
- Execution is limited by a timeout to stop infinite loops
|
| 9 |
+
- Memory limits are enforced by default (256MB)
|
| 10 |
+
- stdout and stderr are captured and returned
|
| 11 |
+
- Code runs in a temporary directory that is deleted afterwards
|
| 12 |
+
- Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen)
|
| 13 |
+
|
| 14 |
+
What is not covered:
|
| 15 |
+
- Not a true security sandbox
|
| 16 |
+
- Network access is not blocked (e.g. sockets could be opened)
|
| 17 |
+
- Python's dynamic features (e.g. ctypes) could bypass restrictions
|
| 18 |
+
- No kernel-level isolation (no seccomp, no containers, no virtualization)
|
| 19 |
+
|
| 20 |
+
Overall this sandbox is good for evaluation of generated code and protects against
|
| 21 |
+
accidental destructive behavior, but it is not safe against malicious adversarial code.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import contextlib
|
| 25 |
+
import faulthandler
|
| 26 |
+
import io
|
| 27 |
+
import multiprocessing
|
| 28 |
+
import os
|
| 29 |
+
import platform
|
| 30 |
+
import signal
|
| 31 |
+
import tempfile
|
| 32 |
+
from dataclasses import dataclass
|
| 33 |
+
from typing import Optional
|
| 34 |
+
|
| 35 |
+
# -----------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class ExecutionResult:
|
| 39 |
+
"""Result of executing Python code in a sandbox."""
|
| 40 |
+
success: bool
|
| 41 |
+
stdout: str
|
| 42 |
+
stderr: str
|
| 43 |
+
error: Optional[str] = None
|
| 44 |
+
timeout: bool = False
|
| 45 |
+
memory_exceeded: bool = False
|
| 46 |
+
|
| 47 |
+
def __repr__(self):
|
| 48 |
+
parts = []
|
| 49 |
+
parts.append(f"ExecutionResult(success={self.success}")
|
| 50 |
+
if self.timeout:
|
| 51 |
+
parts.append(", timeout=True")
|
| 52 |
+
if self.memory_exceeded:
|
| 53 |
+
parts.append(", memory_exceeded=True")
|
| 54 |
+
if self.error:
|
| 55 |
+
parts.append(f", error={self.error!r}")
|
| 56 |
+
if self.stdout:
|
| 57 |
+
parts.append(f", stdout={self.stdout!r}")
|
| 58 |
+
if self.stderr:
|
| 59 |
+
parts.append(f", stderr={self.stderr!r}")
|
| 60 |
+
parts.append(")")
|
| 61 |
+
return "".join(parts)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@contextlib.contextmanager
|
| 65 |
+
def time_limit(seconds: float):
|
| 66 |
+
def signal_handler(signum, frame):
|
| 67 |
+
raise TimeoutException("Timed out!")
|
| 68 |
+
|
| 69 |
+
signal.setitimer(signal.ITIMER_REAL, seconds)
|
| 70 |
+
signal.signal(signal.SIGALRM, signal_handler)
|
| 71 |
+
try:
|
| 72 |
+
yield
|
| 73 |
+
finally:
|
| 74 |
+
signal.setitimer(signal.ITIMER_REAL, 0)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@contextlib.contextmanager
|
| 78 |
+
def capture_io():
|
| 79 |
+
"""Capture stdout and stderr, and disable stdin."""
|
| 80 |
+
stdout_capture = io.StringIO()
|
| 81 |
+
stderr_capture = io.StringIO()
|
| 82 |
+
stdin_block = WriteOnlyStringIO()
|
| 83 |
+
with contextlib.redirect_stdout(stdout_capture):
|
| 84 |
+
with contextlib.redirect_stderr(stderr_capture):
|
| 85 |
+
with redirect_stdin(stdin_block):
|
| 86 |
+
yield stdout_capture, stderr_capture
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@contextlib.contextmanager
|
| 90 |
+
def create_tempdir():
|
| 91 |
+
with tempfile.TemporaryDirectory() as dirname:
|
| 92 |
+
with chdir(dirname):
|
| 93 |
+
yield dirname
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class TimeoutException(Exception):
|
| 97 |
+
pass
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class WriteOnlyStringIO(io.StringIO):
|
| 101 |
+
"""StringIO that throws an exception when it's read from"""
|
| 102 |
+
|
| 103 |
+
def read(self, *args, **kwargs):
|
| 104 |
+
raise IOError
|
| 105 |
+
|
| 106 |
+
def readline(self, *args, **kwargs):
|
| 107 |
+
raise IOError
|
| 108 |
+
|
| 109 |
+
def readlines(self, *args, **kwargs):
|
| 110 |
+
raise IOError
|
| 111 |
+
|
| 112 |
+
def readable(self, *args, **kwargs):
|
| 113 |
+
"""Returns True if the IO object can be read."""
|
| 114 |
+
return False
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class redirect_stdin(contextlib._RedirectStream): # type: ignore
|
| 118 |
+
_stream = "stdin"
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@contextlib.contextmanager
|
| 122 |
+
def chdir(root):
|
| 123 |
+
if root == ".":
|
| 124 |
+
yield
|
| 125 |
+
return
|
| 126 |
+
cwd = os.getcwd()
|
| 127 |
+
os.chdir(root)
|
| 128 |
+
try:
|
| 129 |
+
yield
|
| 130 |
+
except BaseException as exc:
|
| 131 |
+
raise exc
|
| 132 |
+
finally:
|
| 133 |
+
os.chdir(cwd)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def reliability_guard(maximum_memory_bytes: Optional[int] = None):
|
| 137 |
+
"""
|
| 138 |
+
This disables various destructive functions and prevents the generated code
|
| 139 |
+
from interfering with the test (e.g. fork bomb, killing other processes,
|
| 140 |
+
removing filesystem files, etc.)
|
| 141 |
+
|
| 142 |
+
WARNING
|
| 143 |
+
This function is NOT a security sandbox. Untrusted code, including model-
|
| 144 |
+
generated code, should not be blindly executed outside of one. See the
|
| 145 |
+
Codex paper for more information about OpenAI's code sandbox, and proceed
|
| 146 |
+
with caution.
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
if maximum_memory_bytes is not None:
|
| 150 |
+
import resource
|
| 151 |
+
|
| 152 |
+
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
|
| 153 |
+
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
|
| 154 |
+
if not platform.uname().system == "Darwin":
|
| 155 |
+
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
|
| 156 |
+
|
| 157 |
+
faulthandler.disable()
|
| 158 |
+
|
| 159 |
+
import builtins
|
| 160 |
+
|
| 161 |
+
builtins.exit = None
|
| 162 |
+
builtins.quit = None
|
| 163 |
+
|
| 164 |
+
import os
|
| 165 |
+
|
| 166 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
| 167 |
+
|
| 168 |
+
os.kill = None
|
| 169 |
+
os.system = None
|
| 170 |
+
os.putenv = None
|
| 171 |
+
os.remove = None
|
| 172 |
+
os.removedirs = None
|
| 173 |
+
os.rmdir = None
|
| 174 |
+
os.fchdir = None
|
| 175 |
+
os.setuid = None
|
| 176 |
+
os.fork = None
|
| 177 |
+
os.forkpty = None
|
| 178 |
+
os.killpg = None
|
| 179 |
+
os.rename = None
|
| 180 |
+
os.renames = None
|
| 181 |
+
os.truncate = None
|
| 182 |
+
os.replace = None
|
| 183 |
+
os.unlink = None
|
| 184 |
+
os.fchmod = None
|
| 185 |
+
os.fchown = None
|
| 186 |
+
os.chmod = None
|
| 187 |
+
os.chown = None
|
| 188 |
+
os.chroot = None
|
| 189 |
+
os.fchdir = None
|
| 190 |
+
os.lchflags = None
|
| 191 |
+
os.lchmod = None
|
| 192 |
+
os.lchown = None
|
| 193 |
+
os.getcwd = None
|
| 194 |
+
os.chdir = None
|
| 195 |
+
|
| 196 |
+
import shutil
|
| 197 |
+
|
| 198 |
+
shutil.rmtree = None
|
| 199 |
+
shutil.move = None
|
| 200 |
+
shutil.chown = None
|
| 201 |
+
|
| 202 |
+
import subprocess
|
| 203 |
+
|
| 204 |
+
subprocess.Popen = None # type: ignore
|
| 205 |
+
|
| 206 |
+
__builtins__["help"] = None
|
| 207 |
+
|
| 208 |
+
import sys
|
| 209 |
+
|
| 210 |
+
sys.modules["ipdb"] = None
|
| 211 |
+
sys.modules["joblib"] = None
|
| 212 |
+
sys.modules["resource"] = None
|
| 213 |
+
sys.modules["psutil"] = None
|
| 214 |
+
sys.modules["tkinter"] = None
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
|
| 218 |
+
"""Execute code in a subprocess with safety guards. Results are written to result_dict."""
|
| 219 |
+
with create_tempdir():
|
| 220 |
+
|
| 221 |
+
# These system calls are needed when cleaning up tempdir.
|
| 222 |
+
import os
|
| 223 |
+
import shutil
|
| 224 |
+
|
| 225 |
+
rmtree = shutil.rmtree
|
| 226 |
+
rmdir = os.rmdir
|
| 227 |
+
chdir = os.chdir
|
| 228 |
+
|
| 229 |
+
# Disable functionalities that can make destructive changes to the test.
|
| 230 |
+
reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
|
| 231 |
+
|
| 232 |
+
# Default to failure
|
| 233 |
+
result_dict.update({
|
| 234 |
+
"success": False,
|
| 235 |
+
"stdout": "",
|
| 236 |
+
"stderr": "",
|
| 237 |
+
"timeout": False,
|
| 238 |
+
"memory_exceeded": False,
|
| 239 |
+
"error": None,
|
| 240 |
+
})
|
| 241 |
+
|
| 242 |
+
try:
|
| 243 |
+
exec_globals = {}
|
| 244 |
+
with capture_io() as (stdout_capture, stderr_capture):
|
| 245 |
+
with time_limit(timeout):
|
| 246 |
+
# WARNING
|
| 247 |
+
# This program exists to execute untrusted model-generated code. Although
|
| 248 |
+
# it is highly unlikely that model-generated code will do something overtly
|
| 249 |
+
# malicious in response to this test suite, model-generated code may act
|
| 250 |
+
# destructively due to a lack of model capability or alignment.
|
| 251 |
+
# Users are strongly encouraged to sandbox this evaluation suite so that it
|
| 252 |
+
# does not perform destructive actions on their host or network. For more
|
| 253 |
+
# information on how OpenAI sandboxes its code, see the accompanying paper.
|
| 254 |
+
# Once you have read this disclaimer and taken appropriate precautions,
|
| 255 |
+
# the exec call below is intentionally enabled in this adaptation; proceed at your own risk:
|
| 256 |
+
exec(code, exec_globals)
|
| 257 |
+
|
| 258 |
+
result_dict.update({
|
| 259 |
+
"success": True,
|
| 260 |
+
"stdout": stdout_capture.getvalue(),
|
| 261 |
+
"stderr": stderr_capture.getvalue(),
|
| 262 |
+
})
|
| 263 |
+
|
| 264 |
+
except TimeoutException:
|
| 265 |
+
result_dict.update({
|
| 266 |
+
"timeout": True,
|
| 267 |
+
"error": "Execution timed out",
|
| 268 |
+
})
|
| 269 |
+
|
| 270 |
+
except MemoryError as e:
|
| 271 |
+
result_dict.update({
|
| 272 |
+
"memory_exceeded": True,
|
| 273 |
+
"error": f"Memory limit exceeded: {e}",
|
| 274 |
+
})
|
| 275 |
+
|
| 276 |
+
except BaseException as e:
|
| 277 |
+
result_dict.update({
|
| 278 |
+
"error": f"{type(e).__name__}: {e}",
|
| 279 |
+
})
|
| 280 |
+
|
| 281 |
+
# Needed for cleaning up.
|
| 282 |
+
shutil.rmtree = rmtree
|
| 283 |
+
os.rmdir = rmdir
|
| 284 |
+
os.chdir = chdir
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def execute_code(
|
| 288 |
+
code: str,
|
| 289 |
+
timeout: float = 5.0, # 5 seconds default
|
| 290 |
+
maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
|
| 291 |
+
) -> ExecutionResult:
|
| 292 |
+
"""
|
| 293 |
+
Execute Python code in a sandboxed environment.
|
| 294 |
+
|
| 295 |
+
Args:
|
| 296 |
+
code: Python code to execute as a string
|
| 297 |
+
timeout: Maximum execution time in seconds (default: 5.0)
|
| 298 |
+
maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)
|
| 299 |
+
|
| 300 |
+
Returns:
|
| 301 |
+
ExecutionResult with success status, stdout/stderr, and error information
|
| 302 |
+
|
| 303 |
+
Example:
|
| 304 |
+
>>> result = execute_code("print('hello world')")
|
| 305 |
+
>>> result.success
|
| 306 |
+
True
|
| 307 |
+
>>> result.stdout
|
| 308 |
+
'hello world\\n'
|
| 309 |
+
"""
|
| 310 |
+
|
| 311 |
+
manager = multiprocessing.Manager()
|
| 312 |
+
result_dict = manager.dict()
|
| 313 |
+
|
| 314 |
+
p = multiprocessing.Process(
|
| 315 |
+
target=_unsafe_execute,
|
| 316 |
+
args=(code, timeout, maximum_memory_bytes, result_dict)
|
| 317 |
+
)
|
| 318 |
+
p.start()
|
| 319 |
+
p.join(timeout=timeout + 1)
|
| 320 |
+
|
| 321 |
+
if p.is_alive():
|
| 322 |
+
p.kill()
|
| 323 |
+
return ExecutionResult(
|
| 324 |
+
success=False,
|
| 325 |
+
stdout="",
|
| 326 |
+
stderr="",
|
| 327 |
+
error="Execution timed out (process killed)",
|
| 328 |
+
timeout=True,
|
| 329 |
+
memory_exceeded=False,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
if not result_dict:
|
| 333 |
+
return ExecutionResult(
|
| 334 |
+
success=False,
|
| 335 |
+
stdout="",
|
| 336 |
+
stderr="",
|
| 337 |
+
error="Execution failed (no result returned)",
|
| 338 |
+
timeout=True,
|
| 339 |
+
memory_exceeded=False,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
return ExecutionResult(
|
| 343 |
+
success=result_dict["success"],
|
| 344 |
+
stdout=result_dict["stdout"],
|
| 345 |
+
stderr=result_dict["stderr"],
|
| 346 |
+
error=result_dict["error"],
|
| 347 |
+
timeout=result_dict["timeout"],
|
| 348 |
+
memory_exceeded=result_dict["memory_exceeded"],
|
| 349 |
+
)
|
| 350 |
+
|
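A short usage sketch for execute_code (editor's addition, not part of the diff); it exercises the happy path and the timeout path described above.

from nanochat.execution import execute_code

ok = execute_code("print(sum(range(10)))")
print(ok.success, repr(ok.stdout))        # expected: True '45\n'

slow = execute_code("while True: pass", timeout=1.0)
print(slow.success, slow.timeout)         # expected: False True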
nanochat/gpt.py
ADDED
|
@@ -0,0 +1,322 @@
|
| 1 |
+
"""
|
| 2 |
+
GPT model (rewrite, a lot simpler)
|
| 3 |
+
Notable features:
|
| 4 |
+
- rotary embeddings (and no positional embeddings)
|
| 5 |
+
- QK norm
|
| 6 |
+
- untied weights for token embedding and lm_head
|
| 7 |
+
- relu^2 activation in MLP
|
| 8 |
+
- norm after token embedding
|
| 9 |
+
- no learnable params in rmsnorm
|
| 10 |
+
- no bias in linear layers
|
| 11 |
+
- Multi-Query Attention (MQA) support for more efficient inference
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import math
|
| 15 |
+
from functools import partial
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
from nanochat.common import get_dist_info, print0
|
| 23 |
+
from nanochat.muon import Muon, DistMuon
|
| 24 |
+
from nanochat.adamw import DistAdamW
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class GPTConfig:
|
| 28 |
+
sequence_len: int = 1024
|
| 29 |
+
vocab_size: int = 50304
|
| 30 |
+
n_layer: int = 12
|
| 31 |
+
n_head: int = 6 # number of query heads
|
| 32 |
+
n_kv_head: int = 6 # number of key/value heads (MQA)
|
| 33 |
+
n_embd: int = 768
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def norm(x):
|
| 37 |
+
# Purely functional rmsnorm with no learnable params
|
| 38 |
+
return F.rms_norm(x, (x.size(-1),))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def apply_rotary_emb(x, cos, sin):
|
| 42 |
+
assert x.ndim == 4 # multihead attention
|
| 43 |
+
d = x.shape[3] // 2
|
| 44 |
+
x1, x2 = x[..., :d], x[..., d:] # split the last dim into two halves
|
| 45 |
+
y1 = x1 * cos + x2 * sin # rotate pairs of dims
|
| 46 |
+
y2 = x1 * (-sin) + x2 * cos
|
| 47 |
+
out = torch.cat([y1, y2], 3) # re-assemble
|
| 48 |
+
out = out.to(x.dtype) # ensure input/output dtypes match
|
| 49 |
+
return out
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def repeat_kv(x, n_rep):
|
| 53 |
+
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
| 54 |
+
if n_rep == 1:
|
| 55 |
+
return x
|
| 56 |
+
bs, n_kv_heads, slen, head_dim = x.shape
|
| 57 |
+
return (
|
| 58 |
+
x[:, :, None, :, :]
|
| 59 |
+
.expand(bs, n_kv_heads, n_rep, slen, head_dim)
|
| 60 |
+
.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class CausalSelfAttention(nn.Module):
|
| 65 |
+
def __init__(self, config, layer_idx):
|
| 66 |
+
super().__init__()
|
| 67 |
+
self.layer_idx = layer_idx
|
| 68 |
+
self.n_head = config.n_head
|
| 69 |
+
self.n_kv_head = config.n_kv_head
|
| 70 |
+
self.n_embd = config.n_embd
|
| 71 |
+
self.head_dim = self.n_embd // self.n_head
|
| 72 |
+
assert self.n_embd % self.n_head == 0
|
| 73 |
+
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
| 74 |
+
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
| 75 |
+
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
| 76 |
+
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
| 77 |
+
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
| 78 |
+
|
| 79 |
+
def forward(self, x, cos_sin, kv_cache):
|
| 80 |
+
B, T, C = x.size()
|
| 81 |
+
|
| 82 |
+
# Project the input to get queries, keys, and values
|
| 83 |
+
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
|
| 84 |
+
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
| 85 |
+
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
| 86 |
+
|
| 87 |
+
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
| 88 |
+
cos, sin = cos_sin
|
| 89 |
+
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
|
| 90 |
+
q, k = norm(q), norm(k) # QK norm
|
| 91 |
+
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
|
| 92 |
+
|
| 93 |
+
# Apply KV cache: insert current k,v into cache, get the full view so far
|
| 94 |
+
if kv_cache is not None:
|
| 95 |
+
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
|
| 96 |
+
Tq = q.size(2) # number of queries in this forward pass
|
| 97 |
+
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
|
| 98 |
+
|
| 99 |
+
# Apply MQA: replicate the key/value heads for each query head
|
| 100 |
+
nrep = self.n_head // self.n_kv_head
|
| 101 |
+
k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
|
| 102 |
+
|
| 103 |
+
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
|
| 104 |
+
if kv_cache is None or Tq == Tk:
|
| 105 |
+
# During training (no KV cache), attend as usual with causal attention
|
| 106 |
+
# And even if there is KV cache, we can still use this simple version when Tq == Tk
|
| 107 |
+
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
| 108 |
+
elif Tq == 1:
|
| 109 |
+
# During inference but with a single query in this forward pass:
|
| 110 |
+
# The query has to attend to all the keys/values in the cache
|
| 111 |
+
y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
|
| 112 |
+
else:
|
| 113 |
+
# During inference AND we have a chunk of queries in this forward pass:
|
| 114 |
+
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
| 115 |
+
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
|
| 116 |
+
prefix_len = Tk - Tq
|
| 117 |
+
if prefix_len > 0: # can't be negative but could be zero
|
| 118 |
+
attn_mask[:, :prefix_len] = True
|
| 119 |
+
# Then, causal attention within this chunk
|
| 120 |
+
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
| 121 |
+
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
| 122 |
+
|
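# --- Illustrative aside (editor's sketch, not part of gpt.py) -----------------
# For a decode chunk of Tq=2 new queries against Tk=5 total keys (3 cached + 2
# new), the mask built above allows the full prefix plus a causal triangle over
# the chunk itself:
import torch
Tq, Tk = 2, 5
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool)
prefix_len = Tk - Tq                                   # = 3
attn_mask[:, :prefix_len] = True
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool))
print(attn_mask)
# tensor([[ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])
# ------------------------------------------------------------------------------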
| 123 |
+
# Re-assemble the heads side by side and project back to residual stream
|
| 124 |
+
y = y.transpose(1, 2).contiguous().view(B, T, -1)
|
| 125 |
+
y = self.c_proj(y)
|
| 126 |
+
return y
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class MLP(nn.Module):
|
| 130 |
+
def __init__(self, config):
|
| 131 |
+
super().__init__()
|
| 132 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
| 133 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
| 134 |
+
|
| 135 |
+
def forward(self, x):
|
| 136 |
+
x = self.c_fc(x)
|
| 137 |
+
x = F.relu(x).square()
|
| 138 |
+
x = self.c_proj(x)
|
| 139 |
+
return x
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class Block(nn.Module):
|
| 143 |
+
def __init__(self, config, layer_idx):
|
| 144 |
+
super().__init__()
|
| 145 |
+
self.attn = CausalSelfAttention(config, layer_idx)
|
| 146 |
+
self.mlp = MLP(config)
|
| 147 |
+
|
| 148 |
+
def forward(self, x, cos_sin, kv_cache):
|
| 149 |
+
x = x + self.attn(norm(x), cos_sin, kv_cache)
|
| 150 |
+
x = x + self.mlp(norm(x))
|
| 151 |
+
return x
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class GPT(nn.Module):
|
| 155 |
+
def __init__(self, config):
|
| 156 |
+
super().__init__()
|
| 157 |
+
self.config = config
|
| 158 |
+
self.transformer = nn.ModuleDict({
|
| 159 |
+
"wte": nn.Embedding(config.vocab_size, config.n_embd),
|
| 160 |
+
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
| 161 |
+
})
|
| 162 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 163 |
+
# To support meta device initialization, we init the rotary embeddings here, but they may be placeholders (on the meta device); init_weights() recomputes them for real
|
| 164 |
+
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
| 165 |
+
# so let's just over-compute them, but assert fail if we ever reach that amount.
|
| 166 |
+
# In the future we can dynamically grow the cache, for now it's fine.
|
| 167 |
+
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
|
| 168 |
+
head_dim = config.n_embd // config.n_head
|
| 169 |
+
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
| 170 |
+
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
| 171 |
+
self.register_buffer("sin", sin, persistent=False)
|
| 172 |
+
# Cast the embeddings from fp32 to bf16: the optimizer can tolerate it, and it saves memory both in the model and in the activations
|
| 173 |
+
self.transformer.wte.to(dtype=torch.bfloat16)
|
| 174 |
+
|
| 175 |
+
def init_weights(self):
|
| 176 |
+
self.apply(self._init_weights)
|
| 177 |
+
# zero out classifier weights
|
| 178 |
+
torch.nn.init.zeros_(self.lm_head.weight)
|
| 179 |
+
# zero out c_proj weights in all blocks
|
| 180 |
+
for block in self.transformer.h:
|
| 181 |
+
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
| 182 |
+
torch.nn.init.zeros_(block.attn.c_proj.weight)
|
| 183 |
+
# init the rotary embeddings
|
| 184 |
+
head_dim = self.config.n_embd // self.config.n_head
|
| 185 |
+
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
| 186 |
+
self.cos, self.sin = cos, sin
|
| 187 |
+
|
| 188 |
+
def _init_weights(self, module):
|
| 189 |
+
if isinstance(module, nn.Linear):
|
| 190 |
+
# https://arxiv.org/pdf/2310.17813
|
| 191 |
+
fan_out = module.weight.size(0)
|
| 192 |
+
fan_in = module.weight.size(1)
|
| 193 |
+
std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
|
| 194 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 195 |
+
if module.bias is not None:
|
| 196 |
+
torch.nn.init.zeros_(module.bias)
|
| 197 |
+
elif isinstance(module, nn.Embedding):
|
| 198 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
|
| 199 |
+
|
| 200 |
+
# TODO: bump base theta more, e.g. 100K is more common more recently
|
| 201 |
+
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
| 202 |
+
# autodetect the device from model embeddings
|
| 203 |
+
if device is None:
|
| 204 |
+
device = self.transformer.wte.weight.device
|
| 205 |
+
# stride the channels
|
| 206 |
+
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
|
| 207 |
+
inv_freq = 1.0 / (base ** (channel_range / head_dim))
|
| 208 |
+
# stride the time steps
|
| 209 |
+
t = torch.arange(seq_len, dtype=torch.float32, device=device)
|
| 210 |
+
# calculate the rotation frequencies at each (time, channel) pair
|
| 211 |
+
freqs = torch.outer(t, inv_freq)
|
| 212 |
+
cos, sin = freqs.cos(), freqs.sin()
|
| 213 |
+
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
|
| 214 |
+
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
| 215 |
+
return cos, sin
|
| 216 |
+
|
| 217 |
+
def get_device(self):
|
| 218 |
+
return self.transformer.wte.weight.device
|
| 219 |
+
|
| 220 |
+
def estimate_flops(self):
|
| 221 |
+
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
|
| 222 |
+
nparams = sum(p.numel() for p in self.parameters())
|
| 223 |
+
nparams_embedding = self.transformer.wte.weight.numel()
|
| 224 |
+
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
| 225 |
+
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
|
| 226 |
+
return num_flops_per_token
|
| 227 |
+
|
| 228 |
+
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
|
| 229 |
+
model_dim = self.config.n_embd
|
| 230 |
+
ddp, rank, local_rank, world_size = get_dist_info()
|
| 231 |
+
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
|
| 232 |
+
matrix_params = list(self.transformer.h.parameters())
|
| 233 |
+
embedding_params = list(self.transformer.wte.parameters())
|
| 234 |
+
lm_head_params = list(self.lm_head.parameters())
|
| 235 |
+
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
|
| 236 |
+
# Create the AdamW optimizer for the embedding and lm_head
|
| 237 |
+
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
|
| 238 |
+
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
| 239 |
+
if rank == 0:
|
| 240 |
+
print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
| 241 |
+
adam_groups = [
|
| 242 |
+
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
|
| 243 |
+
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
|
| 244 |
+
]
|
| 245 |
+
adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
|
| 246 |
+
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
|
| 247 |
+
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
|
| 248 |
+
# Create the Muon optimizer for the linear layers
|
| 249 |
+
muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
|
| 250 |
+
MuonFactory = DistMuon if ddp else Muon
|
| 251 |
+
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
|
| 252 |
+
# Combine the two optimizers into one list
|
| 253 |
+
optimizers = [adamw_optimizer, muon_optimizer]
|
| 254 |
+
for opt in optimizers:
|
| 255 |
+
for group in opt.param_groups:
|
| 256 |
+
group["initial_lr"] = group["lr"]
|
| 257 |
+
return optimizers
|
| 258 |
+
|
| 259 |
+
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
| 260 |
+
B, T = idx.size()
|
| 261 |
+
|
| 262 |
+
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
|
| 263 |
+
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
| 264 |
+
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
| 265 |
+
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
| 266 |
+
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
|
| 267 |
+
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
| 268 |
+
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
| 269 |
+
|
| 270 |
+
# Forward the trunk of the Transformer
|
| 271 |
+
x = self.transformer.wte(idx)
|
| 272 |
+
x = norm(x)
|
| 273 |
+
for block in self.transformer.h:
|
| 274 |
+
x = block(x, cos_sin, kv_cache)
|
| 275 |
+
x = norm(x)
|
| 276 |
+
|
| 277 |
+
# Forward the lm_head (compute logits)
|
| 278 |
+
softcap = 15
|
| 279 |
+
if targets is not None:
|
| 280 |
+
# training mode: compute and return the loss
|
| 281 |
+
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
|
| 282 |
+
logits = self.lm_head(x)
|
| 283 |
+
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
| 284 |
+
logits = logits.float() # use tf32/fp32 for logits
|
| 285 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
| 286 |
+
return loss
|
| 287 |
+
else:
|
| 288 |
+
# inference mode: compute and return the logits
|
| 289 |
+
logits = self.lm_head(x)
|
| 290 |
+
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
| 291 |
+
return logits
|
| 292 |
+
|
| 293 |
+
@torch.inference_mode()
|
| 294 |
+
def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
|
| 295 |
+
"""
|
| 296 |
+
Naive autoregressive streaming inference.
|
| 297 |
+
To make it super simple, let's assume:
|
| 298 |
+
- batch size is 1
|
| 299 |
+
- ids and the yielded tokens are simple Python lists and ints
|
| 300 |
+
"""
|
| 301 |
+
assert isinstance(tokens, list)
|
| 302 |
+
device = self.get_device()
|
| 303 |
+
rng = None
|
| 304 |
+
if temperature > 0:
|
| 305 |
+
rng = torch.Generator(device=device)
|
| 306 |
+
rng.manual_seed(seed)
|
| 307 |
+
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
|
| 308 |
+
for _ in range(max_tokens):
|
| 309 |
+
logits = self.forward(ids) # (B, T, vocab_size)
|
| 310 |
+
logits = logits[:, -1, :] # (B, vocab_size)
|
| 311 |
+
if top_k is not None:
|
| 312 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 313 |
+
logits[logits < v[:, [-1]]] = -float('Inf')
|
| 314 |
+
if temperature > 0:
|
| 315 |
+
logits = logits / temperature
|
| 316 |
+
probs = F.softmax(logits, dim=-1)
|
| 317 |
+
next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
|
| 318 |
+
else:
|
| 319 |
+
next_ids = torch.argmax(logits, dim=-1, keepdim=True)
|
| 320 |
+
ids = torch.cat((ids, next_ids), dim=1)
|
| 321 |
+
token = next_ids.item()
|
| 322 |
+
yield token
|
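A hypothetical smoke test for the GPT module above (editor's addition, not part of the diff). Because the token embedding is cast to bfloat16 while the linear weights stay fp32, the forward pass is wrapped in autocast here; exact dtype handling may vary with PyTorch version and device.

import torch
from nanochat.gpt import GPT, GPTConfig

config = GPTConfig(sequence_len=64, vocab_size=256, n_layer=2,
                   n_head=4, n_kv_head=2, n_embd=64)
model = GPT(config)
model.init_weights()
idx = torch.randint(0, config.vocab_size, (1, 16))
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    logits = model(idx)                 # inference path: (1, 16, vocab_size)
    loss = model(idx, targets=idx)      # training path: scalar cross-entropy
print(logits.shape, loss.item())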
nanochat/logo.svg
ADDED
|
|
nanochat/loss_eval.py
ADDED
|
@@ -0,0 +1,63 @@
|
| 1 |
+
"""
|
| 2 |
+
A number of functions that help with evaluating a base model.
|
| 3 |
+
"""
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
|
| 8 |
+
@torch.no_grad()
|
| 9 |
+
def evaluate_bpb(model, batches, steps, token_bytes):
|
| 10 |
+
"""
|
| 11 |
+
Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
|
| 12 |
+
which is a metric independent of the tokenizer vocab size, meaning you are still comparing
|
| 13 |
+
apples:apples if you change the vocab size. The way this works is that instead of just
|
| 14 |
+
calculating the average loss as usual, you calculate the sum loss, and independently
|
| 15 |
+
also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
|
| 16 |
+
the number of bytes that the target tokens represent.
|
| 17 |
+
|
| 18 |
+
The added complexity is so that:
|
| 19 |
+
1) All "normal" tokens are normalized by the length of the token in bytes
|
| 20 |
+
2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
|
| 21 |
+
3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.
|
| 22 |
+
|
| 23 |
+
In addition to evaluate_loss, we need the token_bytes tensor:
|
| 24 |
+
It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for
|
| 25 |
+
each token id, or 0 if the token is to not be counted (e.g. special tokens).
|
| 26 |
+
"""
|
| 27 |
+
# record the losses
|
| 28 |
+
total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
|
| 29 |
+
total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
|
| 30 |
+
batch_iter = iter(batches)
|
| 31 |
+
for _ in range(steps):
|
| 32 |
+
x, y = next(batch_iter)
|
| 33 |
+
loss2d = model(x, y, loss_reduction='none') # (B, T)
|
| 34 |
+
loss2d = loss2d.view(-1) # flatten
|
| 35 |
+
y = y.view(-1) # flatten
|
| 36 |
+
if (y < 0).any():
|
| 37 |
+
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
|
| 38 |
+
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
|
| 39 |
+
valid = y >= 0
|
| 40 |
+
y_safe = torch.where(valid, y, torch.zeros_like(y))
|
| 41 |
+
# map valid targets to their byte length; ignored targets contribute 0 bytes
|
| 42 |
+
num_bytes2d = torch.where(
|
| 43 |
+
valid,
|
| 44 |
+
token_bytes[y_safe],
|
| 45 |
+
torch.zeros_like(y, dtype=token_bytes.dtype)
|
| 46 |
+
)
|
| 47 |
+
total_nats += (loss2d * (num_bytes2d > 0)).sum()
|
| 48 |
+
total_bytes += num_bytes2d.sum()
|
| 49 |
+
else:
|
| 50 |
+
# fast path: no ignored targets, safe to index directly
|
| 51 |
+
num_bytes2d = token_bytes[y]
|
| 52 |
+
total_nats += (loss2d * (num_bytes2d > 0)).sum()
|
| 53 |
+
total_bytes += num_bytes2d.sum()
|
| 54 |
+
# sum reduce across all ranks
|
| 55 |
+
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
| 56 |
+
if world_size > 1:
|
| 57 |
+
dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
|
| 58 |
+
dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
|
| 59 |
+
# move both to cpu, calculate bpb and return
|
| 60 |
+
total_nats = total_nats.item()
|
| 61 |
+
total_bytes = total_bytes.item()
|
| 62 |
+
bpb = total_nats / (math.log(2) * total_bytes)
|
| 63 |
+
return bpb
|
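A small numeric example of the reduction above (editor's addition): bpb = total_nats / (ln(2) * total_bytes), i.e. summed cross-entropy converted from nats to bits and normalized by the byte length of the target tokens.

import math

total_nats = 1200.0   # summed cross-entropy (in nats) over all counted tokens
total_bytes = 800     # summed UTF-8 byte lengths of those target tokens
bpb = total_nats / (math.log(2) * total_bytes)
print(f"{bpb:.3f} bits per byte")   # ~2.164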
nanochat/muon.py
ADDED
|
@@ -0,0 +1,187 @@
|
| 1 |
+
"""
|
| 2 |
+
Muon optimizer from Keller et al.
|
| 3 |
+
Also a lot of borrowing of ideas from modded-nanogpt.
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
|
| 9 |
+
@torch.compile
|
| 10 |
+
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
|
| 11 |
+
"""
|
| 12 |
+
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
| 13 |
+
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
| 14 |
+
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
| 15 |
+
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
| 16 |
+
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
| 17 |
+
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
| 18 |
+
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 19 |
+
"""
|
| 20 |
+
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
|
| 21 |
+
a, b, c = (3.4445, -4.7750, 2.0315)
|
| 22 |
+
X = G.bfloat16()
|
| 23 |
+
if G.size(-2) > G.size(-1):
|
| 24 |
+
X = X.mT
|
| 25 |
+
|
| 26 |
+
# Ensure spectral norm is at most 1
|
| 27 |
+
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 28 |
+
# Perform the NS iterations
|
| 29 |
+
for _ in range(steps):
|
| 30 |
+
A = X @ X.mT
|
| 31 |
+
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
|
| 32 |
+
X = a * X + B @ X
|
| 33 |
+
|
| 34 |
+
if G.size(-2) > G.size(-1):
|
| 35 |
+
X = X.mT
|
| 36 |
+
return X
|
| 37 |
+
|
| 38 |
+
class Muon(torch.optim.Optimizer):
|
| 39 |
+
"""
|
| 40 |
+
Muon - MomentUm Orthogonalized by Newton-schulz
|
| 41 |
+
|
| 42 |
+
https://kellerjordan.github.io/posts/muon/
|
| 43 |
+
|
| 44 |
+
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
|
| 45 |
+
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
|
| 46 |
+
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
|
| 47 |
+
the advantage that it can be stably run in bfloat16 on the GPU.
|
| 48 |
+
|
| 49 |
+
Some warnings:
|
| 50 |
+
- This optimizer should not be used for the embedding layer, the final fully connected layer,
|
| 51 |
+
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
|
| 52 |
+
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
|
| 53 |
+
|
| 54 |
+
Arguments:
|
| 55 |
+
lr: The learning rate used by the internal SGD.
|
| 56 |
+
momentum: The momentum used by the internal SGD.
|
| 57 |
+
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
|
| 58 |
+
ns_steps: The number of Newton-Schulz iteration steps to use.
|
| 59 |
+
"""
|
| 60 |
+
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
|
| 61 |
+
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
|
| 62 |
+
params: list[Tensor] = [*params]
|
| 63 |
+
param_groups = []
|
| 64 |
+
for size in {p.numel() for p in params}:
|
| 65 |
+
group = dict(params=[p for p in params if p.numel() == size])
|
| 66 |
+
param_groups.append(group)
|
| 67 |
+
super().__init__(param_groups, defaults)
|
| 68 |
+
|
| 69 |
+
@torch.no_grad()
|
| 70 |
+
def step(self):
|
| 71 |
+
for group in self.param_groups:
|
| 72 |
+
params: list[Tensor] = group["params"]
|
| 73 |
+
for p in params:
|
| 74 |
+
g = p.grad
|
| 75 |
+
assert g is not None
|
| 76 |
+
state = self.state[p]
|
| 77 |
+
if "momentum_buffer" not in state:
|
| 78 |
+
state["momentum_buffer"] = torch.zeros_like(g)
|
| 79 |
+
buf: Tensor = state["momentum_buffer"]
|
| 80 |
+
buf.lerp_(g, 1 - group["momentum"])
|
| 81 |
+
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
| 82 |
+
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
|
| 83 |
+
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class DistMuon(torch.optim.Optimizer):
|
| 87 |
+
"""
|
| 88 |
+
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz,
|
| 89 |
+
finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
|
| 90 |
+
- reduce_scatter(AVG) for gradient averaging
|
| 91 |
+
- all_gather to replicate updated weights
|
| 92 |
+
|
| 93 |
+
Notes:
|
| 94 |
+
* Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
|
| 95 |
+
params like embeddings or scalars.
|
| 96 |
+
* Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
|
| 97 |
+
by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
|
| 98 |
+
consolidate states beforehand.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
params: iterable of Tensors
|
| 102 |
+
lr: learning rate
|
| 103 |
+
momentum: momentum coefficient in [0,1)
|
| 104 |
+
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
|
| 105 |
+
ns_steps: number of Newton–Schulz iterations for the orthogonalization
|
| 106 |
+
"""
|
| 107 |
+
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
|
| 108 |
+
nesterov: bool = True, ns_steps: int = 5):
|
| 109 |
+
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
|
| 110 |
+
params = list(params)
|
| 111 |
+
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
| 112 |
+
rank = dist.get_rank()
|
| 113 |
+
# Group all parameters by their shape
|
| 114 |
+
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
|
| 115 |
+
param_groups = []
|
| 116 |
+
for shape in shapes:
|
| 117 |
+
            group_params = [p for p in params if p.shape == shape]
            device, dtype = group_params[0].device, group_params[0].dtype
            assert all(p.device == device for p in group_params)
            assert all(p.dtype == dtype for p in group_params)
            if rank == 0:
                print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
            param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Ensure all grads exist
        assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"

        # Kick off all the reduce scatter operations to average up the gradients across all ranks
        all_reduce_futures = []
        for group in self.param_groups:
            params = group["params"]
            zero_buffer = group["zero_buffer"]
            # Go through params in groups of world_size.
            for base_i in range(0, len(params), world_size):
                # The compute owner of each param is rank i % world_size
                owner_idx = base_i + rank
                # each rank stacks up its chunk of world_size params into a list
                rs_input = [p.grad for p in params[base_i:base_i + world_size]]
                # pad rs_input with the zero buffer to complete the group
                rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
                # the output buffer gets strided across the group based on the rank
                rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
                # reduce scatter the gradients within this group of world_size params
                work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
                all_reduce_futures.append(work)

        # Now each rank computes the update and gathers
        future_idx = 0
        all_gather_futures = []
        for group in self.param_groups:
            params = group["params"]
            zero_buffer = group["zero_buffer"]
            # Go through params in groups of world_size.
            for base_i in range(0, len(params), world_size):
                # The compute owner of each param is rank i % world_size
                owner_idx = base_i + rank # calculate the index of the param that this rank owns
                # Wait for the reduce scatter to complete
                all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
                future_idx += 1
                # Owner computes the Muon update, result is in its param
                if owner_idx < len(params):
                    p = params[owner_idx]
                    g = p.grad # now averaged across ranks
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf: Tensor = state["momentum_buffer"]
                    buf.lerp_(g, 1.0 - group["momentum"])
                    g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                    g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                    scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
                    p.add_(g, alpha=-group["lr"] * scale)
                # Replicate updated parameters to all ranks
                ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
                ag_output = params[base_i:base_i + world_size]
                ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
                work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
                all_gather_futures.append(work)

        # Wait for all work to finish
        torch.futures.collect_all(all_gather_futures).wait()
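The ownership scheme above assigns params[base_i + rank] of every group of world_size parameters to the current rank, padding short groups with the zero buffer. A minimal sketch of that indexing, with hypothetical world_size and parameter counts (plain Python, no torch.distributed required):

# Sketch of the Muon ownership/padding scheme; the numbers are illustrative only.
world_size = 4
num_params = 10  # not a multiple of world_size, so the last group needs padding

for rank in range(world_size):
    owned = []
    for base_i in range(0, num_params, world_size):
        owner_idx = base_i + rank          # the param this rank owns within the current group
        if owner_idx < num_params:
            owned.append(owner_idx)        # real parameter: this rank computes its Muon update
        # else: the slot is padded with the zero buffer and the rank only does dummy collective work
    print(f"rank {rank} owns params {owned}")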
nanochat/report.py
ADDED
@@ -0,0 +1,404 @@
"""
Utilities for generating training report cards. More messy code than usual, will fix.
"""

import os
import re
import shutil
import subprocess
import socket
import datetime
import platform
import psutil
import torch

def run_command(cmd):
    """Run a shell command and return output, or None if it fails."""
    try:
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
        if result.returncode == 0:
            return result.stdout.strip()
        return None
    except:
        return None

def get_git_info():
    """Get current git commit, branch, and dirty status."""
    info = {}
    info['commit'] = run_command("git rev-parse --short HEAD") or "unknown"
    info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"

    # Check if repo is dirty (has uncommitted changes)
    status = run_command("git status --porcelain")
    info['dirty'] = bool(status) if status is not None else False

    # Get commit message
    info['message'] = run_command("git log -1 --pretty=%B") or ""
    info['message'] = info['message'].split('\n')[0][:80] # First line, truncated

    return info

def get_gpu_info():
    """Get GPU information."""
    if not torch.cuda.is_available():
        return {"available": False}

    num_devices = torch.cuda.device_count()
    info = {
        "available": True,
        "count": num_devices,
        "names": [],
        "memory_gb": []
    }

    for i in range(num_devices):
        props = torch.cuda.get_device_properties(i)
        info["names"].append(props.name)
        info["memory_gb"].append(props.total_memory / (1024**3))

    # Get CUDA version
    info["cuda_version"] = torch.version.cuda or "unknown"

    return info

def get_system_info():
    """Get system information."""
    info = {}

    # Basic system info
    info['hostname'] = socket.gethostname()
    info['platform'] = platform.system()
    info['python_version'] = platform.python_version()
    info['torch_version'] = torch.__version__

    # CPU and memory
    info['cpu_count'] = psutil.cpu_count(logical=False)
    info['cpu_count_logical'] = psutil.cpu_count(logical=True)
    info['memory_gb'] = psutil.virtual_memory().total / (1024**3)

    # User and environment
    info['user'] = os.environ.get('USER', 'unknown')
    info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out')
    info['working_dir'] = os.getcwd()

    return info

def estimate_cost(gpu_info, runtime_hours=None):
    """Estimate training cost based on GPU type and runtime."""

    # Rough pricing, from Lambda Cloud
    default_rate = 2.0
    gpu_hourly_rates = {
        "H100": 3.00,
        "A100": 1.79,
        "V100": 0.55,
    }

    if not gpu_info.get("available"):
        return None

    # Try to identify GPU type from name
    hourly_rate = None
    gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown"
    for gpu_type, rate in gpu_hourly_rates.items():
        if gpu_type in gpu_name:
            hourly_rate = rate * gpu_info["count"]
            break

    if hourly_rate is None:
        hourly_rate = default_rate * gpu_info["count"] # Default estimate

    return {
        "hourly_rate": hourly_rate,
        "gpu_type": gpu_name,
        "estimated_total": hourly_rate * runtime_hours if runtime_hours else None
    }

def generate_header():
    """Generate the header for a training report."""
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    git_info = get_git_info()
    gpu_info = get_gpu_info()
    sys_info = get_system_info()
    cost_info = estimate_cost(gpu_info)

    header = f"""# nanochat training report

Generated: {timestamp}

## Environment

### Git Information
- Branch: {git_info['branch']}
- Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"}
- Message: {git_info['message']}

### Hardware
- Platform: {sys_info['platform']}
- CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical)
- Memory: {sys_info['memory_gb']:.1f} GB
"""

    if gpu_info.get("available"):
        gpu_names = ", ".join(set(gpu_info["names"]))
        total_vram = sum(gpu_info["memory_gb"])
        header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
- GPU Memory: {total_vram:.1f} GB total
- CUDA Version: {gpu_info['cuda_version']}
"""
    else:
        header += "- GPUs: None available\n"

    if cost_info and cost_info["hourly_rate"] > 0:
        header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n"""

    header += f"""
### Software
- Python: {sys_info['python_version']}
- PyTorch: {sys_info['torch_version']}

"""

    # bloat metrics: package all of the source code and assess its weight
    packaged = run_command('files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml')
    num_chars = len(packaged)
    num_lines = len(packaged.split('\n'))
    num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
    num_tokens = num_chars // 4 # assume approximately 4 chars per token

    # count dependencies via uv.lock
    uv_lock_lines = 0
    if os.path.exists('uv.lock'):
        with open('uv.lock', 'r') as f:
            uv_lock_lines = len(f.readlines())

    header += f"""
### Bloat
- Characters: {num_chars:,}
- Lines: {num_lines:,}
- Files: {num_files:,}
- Tokens (approx): {num_tokens:,}
- Dependencies (uv.lock lines): {uv_lock_lines:,}

"""
    return header

# -----------------------------------------------------------------------------

def slugify(text):
    """Slugify a text string."""
    return text.lower().replace(" ", "-")

# the expected files and their order
EXPECTED_FILES = [
    "tokenizer-training.md",
    "tokenizer-evaluation.md",
    "base-model-training.md",
    "base-model-loss.md",
    "base-model-evaluation.md",
    "midtraining.md",
    "chat-evaluation-mid.md",
    "chat-sft.md",
    "chat-evaluation-sft.md",
    "chat-rl.md",
    "chat-evaluation-rl.md",
]
# the metrics we're currently interested in
chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]

def extract(section, keys):
    """simple def to extract a single key from a section"""
    if not isinstance(keys, list):
        keys = [keys] # convenience
    out = {}
    for line in section.split("\n"):
        for key in keys:
            if key in line:
                out[key] = line.split(":")[1].strip()
    return out

def extract_timestamp(content, prefix):
    """Extract timestamp from content with given prefix."""
    for line in content.split('\n'):
        if line.startswith(prefix):
            time_str = line.split(":", 1)[1].strip()
            try:
                return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
            except:
                pass
    return None

class Report:
    """Maintains a bunch of logs, generates a final markdown report."""

    def __init__(self, report_dir):
        os.makedirs(report_dir, exist_ok=True)
        self.report_dir = report_dir

    def log(self, section, data):
        """Log a section of data to the report."""
        slug = slugify(section)
        file_name = f"{slug}.md"
        file_path = os.path.join(self.report_dir, file_name)
        with open(file_path, "w") as f:
            f.write(f"## {section}\n")
            f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            for item in data:
                if not item:
                    # skip falsy values like None or empty dict etc.
                    continue
                if isinstance(item, str):
                    # directly write the string
                    f.write(item)
                else:
                    # render a dict
                    for k, v in item.items():
                        if isinstance(v, float):
                            vstr = f"{v:.4f}"
                        elif isinstance(v, int) and v >= 10000:
                            vstr = f"{v:,.0f}"
                        else:
                            vstr = str(v)
                        f.write(f"- {k}: {vstr}\n")
            f.write("\n")
        return file_path

    def generate(self):
        """Generate the final report."""
        report_dir = self.report_dir
        report_file = os.path.join(report_dir, "report.md")
        print(f"Generating report to {report_file}")
        final_metrics = {} # the most important final metrics we'll add as table at the end
        start_time = None
        end_time = None
        with open(report_file, "w") as out_file:
            # write the header first
            header_file = os.path.join(report_dir, "header.md")
            if os.path.exists(header_file):
                with open(header_file, "r") as f:
                    header_content = f.read()
                out_file.write(header_content)
                start_time = extract_timestamp(header_content, "Run started:")
                # capture bloat data for summary later (the stuff after Bloat header and until \n\n)
                bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
                bloat_data = bloat_data.group(1) if bloat_data else ""
            # process all the individual sections
            for file_name in EXPECTED_FILES:
                section_file = os.path.join(report_dir, file_name)
                if not os.path.exists(section_file):
                    print(f"Warning: {section_file} does not exist, skipping")
                    continue
                with open(section_file, "r") as in_file:
                    section = in_file.read()
                # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
                if "rl" not in file_name:
                    # Skip RL sections for end_time calculation because RL is experimental
                    end_time = extract_timestamp(section, "timestamp:")
                # extract the most important metrics from the sections
                if file_name == "base-model-evaluation.md":
                    final_metrics["base"] = extract(section, "CORE")
                if file_name == "chat-evaluation-mid.md":
                    final_metrics["mid"] = extract(section, chat_metrics)
                if file_name == "chat-evaluation-sft.md":
                    final_metrics["sft"] = extract(section, chat_metrics)
                if file_name == "chat-evaluation-rl.md":
                    final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
                # append this section of the report
                out_file.write(section)
                out_file.write("\n")
            # add the final metrics table
            out_file.write("## Summary\n\n")
            # Copy over the bloat metrics from the header
            out_file.write(bloat_data)
            out_file.write("\n\n")
            # Collect all unique metric names
            all_metrics = set()
            for stage_metrics in final_metrics.values():
                all_metrics.update(stage_metrics.keys())
            # Custom ordering: CORE first, ChatCORE last, rest in middle
            all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
            # Fixed column widths
            stages = ["base", "mid", "sft", "rl"]
            metric_width = 15
            value_width = 8
            # Write table header
            header = f"| {'Metric'.ljust(metric_width)} |"
            for stage in stages:
                header += f" {stage.upper().ljust(value_width)} |"
            out_file.write(header + "\n")
            # Write separator
            separator = f"|{'-' * (metric_width + 2)}|"
            for stage in stages:
                separator += f"{'-' * (value_width + 2)}|"
            out_file.write(separator + "\n")
            # Write table rows
            for metric in all_metrics:
                row = f"| {metric.ljust(metric_width)} |"
                for stage in stages:
                    value = final_metrics.get(stage, {}).get(metric, "-")
                    row += f" {str(value).ljust(value_width)} |"
                out_file.write(row + "\n")
            out_file.write("\n")
            # Calculate and write total wall clock time
            if start_time and end_time:
                duration = end_time - start_time
                total_seconds = int(duration.total_seconds())
                hours = total_seconds // 3600
                minutes = (total_seconds % 3600) // 60
                out_file.write(f"Total wall clock time: {hours}h{minutes}m\n")
            else:
                out_file.write("Total wall clock time: unknown\n")
        # also cp the report.md file to current directory
        print(f"Copying report.md to current directory for convenience")
        shutil.copy(report_file, "report.md")
        return report_file

    def reset(self):
        """Reset the report."""
        # Remove section files
        for file_name in EXPECTED_FILES:
            file_path = os.path.join(self.report_dir, file_name)
            if os.path.exists(file_path):
                os.remove(file_path)
        # Remove report.md if it exists
        report_file = os.path.join(self.report_dir, "report.md")
        if os.path.exists(report_file):
            os.remove(report_file)
        # Generate and write the header section with start timestamp
        header_file = os.path.join(self.report_dir, "header.md")
        header = generate_header()
        start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with open(header_file, "w") as f:
            f.write(header)
            f.write(f"Run started: {start_time}\n\n---\n\n")
        print(f"Reset report and wrote header to {header_file}")

# -----------------------------------------------------------------------------
# nanochat-specific convenience functions

class DummyReport:
    def log(self, *args, **kwargs):
        pass
    def reset(self, *args, **kwargs):
        pass

def get_report():
    # just for convenience, only rank 0 logs to report
    from nanochat.common import get_base_dir, get_dist_info
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp_rank == 0:
        report_dir = os.path.join(get_base_dir(), "report")
        return Report(report_dir)
    else:
        return DummyReport()

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
    parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
    args = parser.parse_args()
    if args.command == "generate":
        get_report().generate()
    elif args.command == "reset":
        get_report().reset()
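A minimal usage sketch of the Report class above, assuming a writable report directory and that the files-to-prompt CLI invoked by generate_header() is installed; the section name and metric value below are hypothetical:

# Sketch only: drives reset -> log -> generate, mirroring what the training scripts do.
from nanochat.report import Report

report = Report("out/report")            # creates the directory if needed
report.reset()                           # writes header.md with a "Run started:" timestamp
report.log("Base model evaluation", [    # slugified to out/report/base-model-evaluation.md
    {"CORE": 0.21},                      # hypothetical metric value
    "Evaluated on the CORE benchmark.\n",
])
report.generate()                        # stitches header + sections into report.md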
nanochat/tokenizer.py
ADDED
@@ -0,0 +1,395 @@
"""
BPE Tokenizer in the style of GPT-4.

Two implementations are available:
1) HuggingFace Tokenizer that can do both training and inference but is really confusing
2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
"""

import os
import copy
from functools import lru_cache

SPECIAL_TOKENS = [
    # every document begins with the Beginning of Sequence (BOS) token that delimits documents
    "<|bos|>",
    # tokens below are only used during finetuning to render Conversations into token ids
    "<|user_start|>", # user messages
    "<|user_end|>",
    "<|assistant_start|>", # assistant messages
    "<|assistant_end|>",
    "<|python_start|>", # assistant invokes python REPL tool
    "<|python_end|>",
    "<|output_start|>", # python REPL outputs back to assistant
    "<|output_end|>",
]

# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
# I haven't validated that this is actually a good idea, TODO.
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

# -----------------------------------------------------------------------------
# Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
from tokenizers import Tokenizer as HFTokenizer
from tokenizers import pre_tokenizers, decoders, Regex
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer

class HuggingFaceTokenizer:
    """Light wrapper around HuggingFace Tokenizer for some utilities"""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, hf_path):
        # init from a HuggingFace pretrained tokenizer (e.g. "gpt2")
        tokenizer = HFTokenizer.from_pretrained(hf_path)
        return cls(tokenizer)

    @classmethod
    def from_directory(cls, tokenizer_dir):
        # init from a local directory on disk (e.g. "out/tokenizer")
        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
        tokenizer = HFTokenizer.from_file(tokenizer_path)
        return cls(tokenizer)

    @classmethod
    def train_from_iterator(cls, text_iterator, vocab_size):
        # train from an iterator of text
        # Configure the HuggingFace Tokenizer
        tokenizer = HFTokenizer(BPE(
            byte_fallback=True, # needed!
            unk_token=None,
            fuse_unk=False,
        ))
        # Normalizer: None
        tokenizer.normalizer = None
        # Pre-tokenizer: GPT-4 style
        # the regex pattern used by GPT-4 to split text into groups before BPE
        # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
        # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
        # (but I haven't validated this! TODO)
        gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
            pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
            pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
        ])
        # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
        tokenizer.decoder = decoders.ByteLevel()
        # Post-processor: None
        tokenizer.post_processor = None
        # Trainer: BPE
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            show_progress=True,
            min_frequency=0, # no minimum frequency
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
            special_tokens=SPECIAL_TOKENS,
        )
        # Kick off the training
        tokenizer.train_from_iterator(text_iterator, trainer)
        return cls(tokenizer)

    def get_vocab_size(self):
        return self.tokenizer.get_vocab_size()

    def get_special_tokens(self):
        special_tokens_map = self.tokenizer.get_added_tokens_decoder()
        special_tokens = [w.content for w in special_tokens_map.values()]
        return special_tokens

    def id_to_token(self, id):
        return self.tokenizer.id_to_token(id)

    def _encode_one(self, text, prepend=None, append=None):
        # encode a single string
        # prepend/append can be either a string of a special token or a token id directly.
        assert isinstance(text, str)
        ids = []
        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
            ids.append(prepend_id)
        ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
        if append is not None:
            append_id = append if isinstance(append, int) else self.encode_special(append)
            ids.append(append_id)
        return ids

    def encode_special(self, text):
        # encode a single special token via exact match
        return self.tokenizer.token_to_id(text)

    def get_bos_token_id(self):
        bos = self.encode_special("<|bos|>")
        return bos

    def encode(self, text, *args, **kwargs):
        if isinstance(text, str):
            return self._encode_one(text, *args, **kwargs)
        elif isinstance(text, list):
            return [self._encode_one(t, *args, **kwargs) for t in text]
        else:
            raise ValueError(f"Invalid input type: {type(text)}")

    def __call__(self, *args, **kwargs):
        return self.encode(*args, **kwargs)

    def decode(self, ids):
        return self.tokenizer.decode(ids, skip_special_tokens=False)

    def save(self, tokenizer_dir):
        # save the tokenizer to disk
        os.makedirs(tokenizer_dir, exist_ok=True)
        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
        self.tokenizer.save(tokenizer_path)
        print(f"Saved tokenizer to {tokenizer_path}")

# -----------------------------------------------------------------------------
# Tokenizer based on rustbpe + tiktoken combo
import pickle
import rustbpe
import tiktoken

class RustBPETokenizer:
    """Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""

    def __init__(self, enc, bos_token):
        self.enc = enc
        self.bos_token_id = self.encode_special(bos_token)

    @classmethod
    def train_from_iterator(cls, text_iterator, vocab_size):
        # 1) train using rustbpe
        tokenizer = rustbpe.Tokenizer()
        # the special tokens are inserted later in __init__, we don't train them here
        vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
        assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
        tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
        # 2) construct the associated tiktoken encoding for inference
        pattern = tokenizer.get_pattern()
        mergeable_ranks_list = tokenizer.get_mergeable_ranks()
        mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
        tokens_offset = len(mergeable_ranks)
        special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
        enc = tiktoken.Encoding(
            name="rustbpe",
            pat_str=pattern,
            mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
            special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
        )
        return cls(enc, "<|bos|>")

    @classmethod
    def from_directory(cls, tokenizer_dir):
        pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
        with open(pickle_path, "rb") as f:
            enc = pickle.load(f)
        return cls(enc, "<|bos|>")

    @classmethod
    def from_pretrained(cls, tiktoken_name):
        # https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py
        enc = tiktoken.get_encoding(tiktoken_name)
        # tiktoken calls the special document delimiter token "<|endoftext|>"
        # yes this is confusing because this token is almost always PREPENDED to the beginning of the document
        # it most often is used to signal the start of a new sequence to the LLM during inference etc.
        # so in nanoChat we always use "<|bos|>" short for "beginning of sequence", but historically it is often called "<|endoftext|>".
        return cls(enc, "<|endoftext|>")

    def get_vocab_size(self):
        return self.enc.n_vocab

    def get_special_tokens(self):
        return self.enc.special_tokens_set

    def id_to_token(self, id):
        return self.enc.decode([id])

    @lru_cache(maxsize=32)
    def encode_special(self, text):
        return self.enc.encode_single_token(text)

    def get_bos_token_id(self):
        return self.bos_token_id

    def encode(self, text, prepend=None, append=None, num_threads=8):
        # text can be either a string or a list of strings

        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
        if append is not None:
            append_id = append if isinstance(append, int) else self.encode_special(append)

        if isinstance(text, str):
            ids = self.enc.encode_ordinary(text)
            if prepend is not None:
                ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
            if append is not None:
                ids.append(append_id)
        elif isinstance(text, list):
            ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
            if prepend is not None:
                for ids_row in ids:
                    ids_row.insert(0, prepend_id) # TODO: same
            if append is not None:
                for ids_row in ids:
                    ids_row.append(append_id)
        else:
            raise ValueError(f"Invalid input type: {type(text)}")

        return ids

    def __call__(self, *args, **kwargs):
        return self.encode(*args, **kwargs)

    def decode(self, ids):
        return self.enc.decode(ids)

    def save(self, tokenizer_dir):
        # save the encoding object to disk
        os.makedirs(tokenizer_dir, exist_ok=True)
        pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
        with open(pickle_path, "wb") as f:
            pickle.dump(self.enc, f)
        print(f"Saved tokenizer encoding to {pickle_path}")

    def render_conversation(self, conversation, max_tokens=2048):
        """
        Tokenize a single Chat conversation (which we call a "doc" or "document" here).
        Returns:
        - ids: list[int] is a list of token ids of this rendered conversation
        - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
        """
        # ids, masks that we will return and a helper function to help build them up.
        ids, mask = [], []
        def add_tokens(token_ids, mask_val):
            if isinstance(token_ids, int):
                token_ids = [token_ids]
            ids.extend(token_ids)
            mask.extend([mask_val] * len(token_ids))

        # sometimes the first message is a system message...
        # => just merge it with the second (user) message
        if conversation["messages"][0]["role"] == "system":
            # some conversation surgery is necessary here for now...
            conversation = copy.deepcopy(conversation) # avoid mutating the original
            messages = conversation["messages"]
            assert messages[1]["role"] == "user", "System message must be followed by a user message"
            messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
            messages = messages[1:]
        else:
            messages = conversation["messages"]
        assert len(messages) >= 1, f"Conversation has less than 1 message: {messages}"

        # fetch all the special tokens we need
        bos = self.get_bos_token_id()
        user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
        assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
        python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
        output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")

        # now we can tokenize the conversation
        add_tokens(bos, 0)
        for i, message in enumerate(messages):

            # some sanity checking here around assumptions, to prevent footguns
            must_be_from = "user" if i % 2 == 0 else "assistant"
            assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"

            # content can be either a simple string or a list of parts (e.g. containing tool calls)
            content = message["content"]

            if message["role"] == "user":
                assert isinstance(content, str), "User messages are simply expected to be strings"
                value_ids = self.encode(content)
                add_tokens(user_start, 0)
                add_tokens(value_ids, 0)
                add_tokens(user_end, 0)
            elif message["role"] == "assistant":
                add_tokens(assistant_start, 0)
                if isinstance(content, str):
                    # simple string => simply add the tokens
                    value_ids = self.encode(content)
                    add_tokens(value_ids, 1)
                elif isinstance(content, list):
                    for part in content:
                        value_ids = self.encode(part["text"])
                        if part["type"] == "text":
                            # string part => simply add the tokens
                            add_tokens(value_ids, 1)
                        elif part["type"] == "python":
                            # python tool call => add the tokens inside <|python_start|> and <|python_end|>
                            add_tokens(python_start, 1)
                            add_tokens(value_ids, 1)
                            add_tokens(python_end, 1)
                        elif part["type"] == "python_output":
                            # python output => add the tokens inside <|output_start|> and <|output_end|>
                            # none of these tokens are supervised because the tokens come from Python at test time
                            add_tokens(output_start, 0)
                            add_tokens(value_ids, 0)
                            add_tokens(output_end, 0)
                        else:
                            raise ValueError(f"Unknown part type: {part['type']}")
                else:
                    raise ValueError(f"Unknown content type: {type(content)}")
                add_tokens(assistant_end, 1)

        # truncate to max_tokens tokens MAX (helps prevent OOMs)
        ids = ids[:max_tokens]
        mask = mask[:max_tokens]
        return ids, mask

    def visualize_tokenization(self, ids, mask):
        """Small helper function useful in debugging: visualize the tokenization of render_conversation"""
        RED = '\033[91m'
        GREEN = '\033[92m'
        RESET = '\033[0m'
        tokens = []
        for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
            token_str = self.decode([token_id])
            color = GREEN if mask_val == 1 else RED
            tokens.append(f"{color}{token_str}{RESET}")
        return '|'.join(tokens)

    def render_for_completion(self, conversation):
        """
        Used during Reinforcement Learning. In that setting, we want to
        render the conversation priming the Assistant for a completion.
        Unlike the Chat SFT case, we don't need to return the mask.
        """
        # We have some surgery to do: we need to pop the last message (of the Assistant)
        conversation = copy.deepcopy(conversation) # avoid mutating the original
        messages = conversation["messages"]
        assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
        messages.pop() # remove the last message (of the Assistant) inplace

        # Now tokenize the conversation
        ids, mask = self.render_conversation(conversation)

        # Finally, to prime the Assistant for a completion, append the Assistant start token
        assistant_start = self.encode_special("<|assistant_start|>")
        ids.append(assistant_start)
        return ids

# -----------------------------------------------------------------------------
# nanochat-specific convenience functions

def get_tokenizer():
    from nanochat.common import get_base_dir
    base_dir = get_base_dir()
    tokenizer_dir = os.path.join(base_dir, "tokenizer")
    # return HuggingFaceTokenizer.from_directory(tokenizer_dir)
    return RustBPETokenizer.from_directory(tokenizer_dir)

def get_token_bytes(device="cpu"):
    import torch
    from nanochat.common import get_base_dir
    base_dir = get_base_dir()
    tokenizer_dir = os.path.join(base_dir, "tokenizer")
    token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
    assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
    with open(token_bytes_path, "rb") as f:
        token_bytes = torch.load(f, map_location=device)
    return token_bytes
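A minimal sketch of rendering a chat conversation with the RustBPETokenizer above, assuming a tokenizer has already been trained and saved to disk (e.g. by the tokenizer training script); the conversation content below is hypothetical:

# Sketch only: shows render_conversation ids/mask and the debug visualization.
from nanochat.tokenizer import get_tokenizer

tokenizer = get_tokenizer()  # loads tokenizer.pkl from the base dir's tokenizer/ folder
conversation = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 4."},
    ]
}
ids, mask = tokenizer.render_conversation(conversation)
# mask is 1 only on assistant tokens (and <|assistant_end|>), 0 everywhere else
print(tokenizer.visualize_tokenization(ids, mask))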
nanochat/ui.html
ADDED
@@ -0,0 +1,394 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>NanoChat</title>
    <link rel="icon" type="image/svg+xml" href="/logo.svg">
    <style>
        :root {
            color-scheme: light;
        }

        * {
            box-sizing: border-box;
        }

        body {
            font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
            background-color: #ffffff;
            color: #111827;
            min-height: 100vh;
            margin: 0;
            display: flex;
            flex-direction: column;
        }

        .header {
            background-color: #ffffff;
            padding: 1.25rem 1.5rem;
        }

        .header-left {
            display: flex;
            align-items: center;
            gap: 0.75rem;
        }

        .header-logo {
            height: 32px;
            width: auto;
        }

        .header h1 {
            font-size: 1.25rem;
            font-weight: 600;
            margin: 0;
            color: #111827;
        }

        .new-conversation-btn {
            width: 32px;
            height: 32px;
            padding: 0;
            border: 1px solid #e5e7eb;
            border-radius: 0.5rem;
            background-color: #ffffff;
            color: #6b7280;
            cursor: pointer;
            display: flex;
            align-items: center;
            justify-content: center;
            transition: all 0.2s ease;
        }

        .new-conversation-btn:hover {
            background-color: #f3f4f6;
            border-color: #d1d5db;
            color: #374151;
        }

        .chat-container {
            flex: 1;
            overflow-y: auto;
            background-color: #ffffff;
        }

        .chat-wrapper {
            max-width: 48rem;
            margin: 0 auto;
            padding: 2rem 1.5rem 3rem;
            display: flex;
            flex-direction: column;
            gap: 0.75rem;
        }

        .message {
            display: flex;
            justify-content: flex-start;
            margin-bottom: 0.5rem;
            color: #0d0d0d;
        }

        .message.assistant {
            justify-content: flex-start;
        }

        .message.user {
            justify-content: flex-end;
        }

        .message-content {
            white-space: pre-wrap;
            line-height: 1.6;
            max-width: 100%;
        }

        .message.assistant .message-content {
            background: transparent;
            border: none;
            padding: 0.25rem 0;
        }

        .message.user .message-content {
            background-color: #f3f4f6;
            border-radius: 1.25rem;
            padding: 0.8rem 1rem;
            max-width: 65%;
        }

        .input-container {
            background-color: #ffffff;
            padding: 1rem;
        }

        .input-wrapper {
            max-width: 48rem;
            margin: 0 auto;
            display: flex;
            gap: 0.75rem;
            align-items: flex-end;
        }

        .chat-input {
            flex: 1;
            padding: 0.8rem 1rem;
            border: 1px solid #d1d5db;
            border-radius: 0.75rem;
            background-color: #ffffff;
            color: #111827;
            font-size: 1rem;
            line-height: 1.5;
            resize: none;
            outline: none;
            min-height: 54px;
            max-height: 200px;
            transition: border-color 0.2s ease, box-shadow 0.2s ease;
        }

        .chat-input::placeholder {
            color: #9ca3af;
        }

        .chat-input:focus {
            border-color: #2563eb;
            box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
        }

        .send-button {
            flex-shrink: 0;
            padding: 0;
            width: 54px;
            height: 54px;
            border: 1px solid #111827;
            border-radius: 0.75rem;
            background-color: #111827;
            color: #ffffff;
            display: flex;
            align-items: center;
            justify-content: center;
            cursor: pointer;
            transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
        }

        .send-button:hover:not(:disabled) {
            background-color: #2563eb;
            border-color: #2563eb;
        }

        .send-button:disabled {
            cursor: not-allowed;
            border-color: #d1d5db;
            background-color: #e5e7eb;
            color: #9ca3af;
        }

        .typing-indicator {
            display: inline-block;
            color: #6b7280;
            letter-spacing: 0.15em;
        }

        .typing-indicator::after {
            content: '···';
            animation: typing 1.4s infinite;
        }

        @keyframes typing {
            0%, 60%, 100% { opacity: 0.2; }
            30% { opacity: 1; }
        }

        .error-message {
            background-color: #fee2e2;
            border: 1px solid #fecaca;
            color: #b91c1c;
            padding: 0.75rem 1rem;
            border-radius: 0.75rem;
            margin-top: 0.5rem;
        }
    </style>
</head>
<body>
    <div class="header">
        <div class="header-left">
            <button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
                <svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                    <path d="M12 5v14"></path>
                    <path d="M5 12h14"></path>
                </svg>
            </button>
            <h1>nanochat</h1>
        </div>
    </div>

    <div class="chat-container" id="chatContainer">
        <div class="chat-wrapper" id="chatWrapper">
            <!-- Messages will be added here -->
        </div>
    </div>

    <div class="input-container">
        <div class="input-wrapper">
            <textarea
                id="chatInput"
                class="chat-input"
                placeholder="Ask anything"
                rows="1"
                onkeydown="handleKeyDown(event)"
            ></textarea>
            <button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
                <svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                    <path d="M22 2L11 13"></path>
                    <path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
                </svg>
            </button>
        </div>
    </div>

    <script>
        const API_URL = '';
        const chatContainer = document.getElementById('chatContainer');
        const chatWrapper = document.getElementById('chatWrapper');
        const chatInput = document.getElementById('chatInput');
        const sendButton = document.getElementById('sendButton');

        let messages = [];
        let isGenerating = false;

        chatInput.addEventListener('input', function() {
            this.style.height = 'auto';
            this.style.height = Math.min(this.scrollHeight, 200) + 'px';
            sendButton.disabled = !this.value.trim() || isGenerating;
        });

        function handleKeyDown(event) {
            if (event.key === 'Enter' && !event.shiftKey) {
                event.preventDefault();
                sendMessage();
            }
        }

        document.addEventListener('keydown', function(event) {
            // Ctrl+Shift+N for new conversation
            if (event.ctrlKey && event.shiftKey && event.key === 'N') {
                event.preventDefault();
                if (!isGenerating) {
                    newConversation();
                }
            }
        });

        function newConversation() {
            messages = [];
            chatWrapper.innerHTML = '';
            chatInput.value = '';
            chatInput.style.height = 'auto';
            sendButton.disabled = false;
            isGenerating = false;
            chatInput.focus();
        }

        function addMessage(role, content) {
            const messageDiv = document.createElement('div');
            messageDiv.className = `message ${role}`;

            const contentDiv = document.createElement('div');
            contentDiv.className = 'message-content';
            contentDiv.textContent = content;

            messageDiv.appendChild(contentDiv);
            chatWrapper.appendChild(messageDiv);

            chatContainer.scrollTop = chatContainer.scrollHeight;
            return contentDiv;
        }

        async function sendMessage() {
            const message = chatInput.value.trim();
            if (!message || isGenerating) return;

            isGenerating = true;
            chatInput.value = '';
            chatInput.style.height = 'auto';
            sendButton.disabled = true;

            messages.push({ role: 'user', content: message });
            addMessage('user', message);

            const assistantContent = addMessage('assistant', '');
            assistantContent.innerHTML = '<span class="typing-indicator"></span>';

            try {
                const response = await fetch(`${API_URL}/chat/completions`, {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({
                        messages: messages,
                        stream: true,
                        temperature: 0.8,
                        max_tokens: 512
                    }),
                });

                if (!response.ok) {
                    throw new Error(`HTTP error! status: ${response.status}`);
                }

                const reader = response.body.getReader();
                const decoder = new TextDecoder();
                let fullResponse = '';
                assistantContent.textContent = '';

                while (true) {
                    const { done, value } = await reader.read();
                    if (done) break;

                    const chunk = decoder.decode(value);
                    const lines = chunk.split('\n');

                    for (const line of lines) {
                        if (line.startsWith('data: ')) {
                            try {
                                const data = JSON.parse(line.slice(6));
                                if (data.token) {
                                    fullResponse += data.token;
                                    assistantContent.textContent = fullResponse;
                                    chatContainer.scrollTop = chatContainer.scrollHeight;
                                }
                            } catch (e) {
                            }
                        }
                    }
                }

                messages.push({ role: 'assistant', content: fullResponse });

            } catch (error) {
                console.error('Error:', error);
                assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
            } finally {
                isGenerating = false;
                sendButton.disabled = !chatInput.value.trim();
            }
        }

        sendButton.disabled = false;

        // Autofocus the chat input on page load
        chatInput.focus();

        fetch(`${API_URL}/health`)
            .then(response => response.json())
            .then(data => {
                console.log('Engine status:', data);
            })
            .catch(error => {
                console.error('Engine not available:', error);
                chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
            });
    </script>
</body>
</html>
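The page above consumes a /chat/completions endpoint that streams SSE-style "data: {...}" lines carrying a "token" field. A hedged Python sketch of the same client logic, assuming the engine is served at a hypothetical localhost:8000 and that the requests package is available:

# Sketch only: mirrors the browser's streaming loop from ui.html.
import json
import requests  # third-party; assumed installed

url = "http://localhost:8000/chat/completions"  # hypothetical address of the running engine
payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
    "temperature": 0.8,
    "max_tokens": 512,
}
with requests.post(url, json=payload, stream=True) as r:
    r.raise_for_status()
    for line in r.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            data = json.loads(line[len("data: "):])
            if "token" in data:
                print(data["token"], end="", flush=True)
print()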