loocorez committed
Commit df7f6f1 · verified · 1 Parent(s): 6e96dcc

Upload folder using huggingface_hub

nanochat/__init__.py ADDED
File without changes
nanochat/adamw.py ADDED
@@ -0,0 +1,77 @@
"""
Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
Not a general optimizer! But works for our specific use.
"""
import torch
import torch.distributed as dist
from torch import Tensor


class DistAdamW(torch.optim.Optimizer):
    """
    Distributed AdamW optimizer.
    In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
    """
    def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(param_groups, defaults)

    @torch.compile
    @torch.no_grad()
    def step(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        reduce_scatter_futures: list[torch.Future] = []
        all_reduce_futures: list[torch.Future] = []
        grad_slices = []
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly
            for base_i in range(len(params)):
                grad = params[base_i].grad
                rank_size = grad.shape[0] // world_size
                grad_slice = torch.empty_like(grad[:rank_size])
                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                grad_slices.append(grad_slice)

        idx = 0
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            eps = group['eps']
            wd = group['weight_decay']
            params = group['params']
            for base in range(len(params)):
                reduce_scatter_futures[idx].wait()
                p = params[base]
                rank_size = p.shape[0] // world_size
                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
                state = self.state[p]
                g_slice = grad_slices[idx]
                # State init
                if not state:
                    state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
                    state['exp_avg'] = torch.zeros_like(p_slice)
                    state['exp_avg_sq'] = torch.zeros_like(p_slice)
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                state['step'] += 1
                t = state['step']
                # weight decay
                if wd != 0:
                    eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
                    p_slice.mul_(1 - eff_weight_decay)
                # update running averages
                exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
                # bias corrections
                bias1 = 1 - beta1 ** t
                bias2 = 1 - beta2 ** t
                # compute step
                denom = exp_avg_sq.sqrt().add_(eps)
                step_size = lr * (torch.sqrt(bias2) / bias1)
                update = exp_avg.div(denom).mul_(step_size)
                p_slice.add_(other=update, alpha=-1.0)
                idx += 1
                all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
        torch.futures.collect_all(all_reduce_futures).wait()
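
For context, a minimal usage sketch of DistAdamW (editorial, not part of the commit). It assumes the script is launched under torchrun so the process group can be initialized, and that each parameter's first dimension divides evenly by the world size, since gradients are reduce-scattered and parameters all-gathered along dim 0. The lr_mul/wd_mul attributes are the optional per-parameter multipliers that step() reads via getattr.

import os
import torch
import torch.distributed as dist
from nanochat.adamw import DistAdamW

# assumption: launched via `torchrun --nproc_per_node=N this_script.py`
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")

# toy model; DistAdamW shards along dim 0, so shape[0] should divide by world size
model = torch.nn.Linear(1024, 1024, bias=False).cuda()
for p in model.parameters():
    p.lr_mul = 1.0  # optional per-parameter multipliers, read via getattr() in step()
    p.wd_mul = 1.0
opt = DistAdamW([{"params": list(model.parameters()), "lr": 3e-4}], betas=(0.9, 0.95), weight_decay=0.1)

loss = model(torch.randn(8, 1024, device="cuda")).square().mean()
loss.backward()
opt.step()  # reduce-scatter grads, update local shard, all-gather updated params
opt.zero_grad(set_to_none=True)
dist.destroy_process_group()
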
nanochat/checkpoint_manager.py ADDED
@@ -0,0 +1,146 @@
"""
Utilities for saving and loading model/optim/state checkpoints.
"""
import os
import re
import glob
import json
import logging
import torch

from nanochat.common import get_base_dir
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import get_tokenizer
from nanochat.common import setup_default_logging

# Set up logging
setup_default_logging()
logger = logging.getLogger(__name__)
def log0(message):
    if int(os.environ.get('RANK', 0)) == 0:
        logger.info(message)

def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
    assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Save the model state (parameters)
    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
    torch.save(model_data, model_path)
    log0(f"Saved model file to: {model_path}")
    # Save the optimizer state (useful for SFT or any other fine-tuning)
    if optimizer_data is not None:
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
        torch.save(optimizer_data, optimizer_path)
        log0(f"Saved optimizer file to: {optimizer_path}")
    # Save the metadata dict as json
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    with open(meta_path, "w") as f:
        json.dump(meta_data, f, indent=2)
    log0(f"Saved metadata file to: {meta_path}")


def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
    # Load the model state
    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
    model_data = torch.load(model_path, map_location=device)
    # Load the optimizer state if requested
    optimizer_data = None
    if load_optimizer:
        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
        optimizer_data = torch.load(optimizer_path, map_location=device)
    # Load the metadata
    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    with open(meta_path, "r") as f:
        meta_data = json.load(f)
    return model_data, optimizer_data, meta_data


def build_model(checkpoint_dir, step, device, phase):
    """
    A bunch of repetitive code to build a model from a given checkpoint.
    Returns:
    - base model - uncompiled, not wrapped in DDP
    - tokenizer
    - meta data saved during base model training
    """
    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
    # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
    model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()}
    model_config_kwargs = meta_data["model_config"]
    log0(f"Building model with config: {model_config_kwargs}")
    model_config = GPTConfig(**model_config_kwargs)
    with torch.device("meta"):
        model = GPT(model_config)
    # Load the model state
    model.to_empty(device=device)
    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
    model.load_state_dict(model_data, strict=True, assign=True)
    # Put the model in the right training phase / mode
    if phase == "eval":
        model.eval()
    else:
        model.train()
    # Load the Tokenizer
    tokenizer = get_tokenizer()
    # Sanity check: compatibility between model and tokenizer
    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
    return model, tokenizer, meta_data


def find_largest_model(checkpoint_dir):
    # attempt to guess the model tag: take the biggest model available
    model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
    if not model_tags:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    # 1) normally all model tags are of the form d<number>, try that first:
    candidates = []
    for model_tag in model_tags:
        match = re.match(r"d(\d+)", model_tag)
        if match:
            model_depth = int(match.group(1))
            candidates.append((model_depth, model_tag))
    if candidates:
        candidates.sort(key=lambda x: x[0], reverse=True)
        return candidates[0][1]
    # 2) if that failed, take the most recently updated model:
    model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
    return model_tags[0]


def find_last_step(checkpoint_dir):
    # Look into checkpoint_dir and find model_<step>.pt with the highest step
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
    if not checkpoint_files:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
    return last_step

# -----------------------------------------------------------------------------
# convenience functions that take into account nanochat's directory structure

def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
    if model_tag is None:
        # guess the model tag by defaulting to the largest model
        model_tag = find_largest_model(checkpoints_dir)
        log0(f"No model tag provided, guessing model tag: {model_tag}")
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
    if step is None:
        # guess the step by defaulting to the last step
        step = find_last_step(checkpoint_dir)
    assert step is not None, f"No checkpoints found in {checkpoint_dir}"
    # build the model
    log0(f"Loading model from {checkpoint_dir} with step {step}")
    model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
    return model, tokenizer, meta_data

def load_model(source, *args, **kwargs):
    model_dir = {
        "base": "base_checkpoints",
        "mid": "mid_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }[source]
    base_dir = get_base_dir()
    checkpoints_dir = os.path.join(base_dir, model_dir)
    return load_model_from_dir(checkpoints_dir, *args, **kwargs)
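
A hedged usage sketch (editorial, not part of the commit): load_model resolves <base_dir>/<source>_checkpoints/, guesses the largest d<depth> model tag and the highest saved step, and returns the uncompiled model plus tokenizer and metadata. It assumes a base checkpoint has already been written by a training run.

import torch
from nanochat.checkpoint_manager import load_model

device = torch.device("cuda")
# tag and step are guessed: largest d<depth> directory, highest model_<step>.pt
model, tokenizer, meta = load_model("base", device, phase="eval")
print(meta["model_config"])  # the GPTConfig kwargs saved at training time
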
nanochat/common.py ADDED
@@ -0,0 +1,136 @@
"""
Common utilities for nanochat.
"""

import os
import re
import logging
import torch
import torch.distributed as dist

class ColoredFormatter(logging.Formatter):
    """Custom formatter that adds colors to log messages."""
    # ANSI color codes
    COLORS = {
        'DEBUG': '\033[36m',    # Cyan
        'INFO': '\033[32m',     # Green
        'WARNING': '\033[33m',  # Yellow
        'ERROR': '\033[31m',    # Red
        'CRITICAL': '\033[35m', # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'
    def format(self, record):
        # Add color to the level name
        levelname = record.levelname
        if levelname in self.COLORS:
            record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
        # Format the message
        message = super().format(record)
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers and percentages
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message

def setup_default_logging():
    handler = logging.StreamHandler()
    handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logging.basicConfig(
        level=logging.INFO,
        handlers=[handler]
    )

setup_default_logging()
logger = logging.getLogger(__name__)

def get_base_dir():
    # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
    if os.environ.get("NANOCHAT_BASE_DIR"):
        nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
    else:
        home_dir = os.path.expanduser("~")
        cache_dir = os.path.join(home_dir, ".cache")
        nanochat_dir = os.path.join(cache_dir, "nanochat")
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir

def print0(s="",**kwargs):
    ddp_rank = int(os.environ.get('RANK', 0))
    if ddp_rank == 0:
        print(s, **kwargs)

def print_banner():
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    banner = """
 █████                                           █████                 █████
░░███                                           ░░███                 ░░███
 ████████    ██████   ████████    ██████   ██████  ░███████    ██████   ███████
░░███░░███  ░░░░░███ ░░███░░███  ███░░███ ███░░███ ░███░░███  ░░░░░███ ░░░███░
 ░███ ░███   ███████  ░███ ░███ ░███ ░███░███ ░░░  ░███ ░███   ███████   ░███
 ░███ ░███  ███░░███  ░███ ░███ ░███ ░███░███  ███ ░███ ░███  ███░░███   ░███ ███
 ████ █████░░████████ ████ █████░░██████ ░░██████  ████ █████░░████████  ░░█████
░░░░ ░░░░░  ░░░░░░░░ ░░░░ ░░░░░  ░░░░░░   ░░░░░░  ░░░░ ░░░░░  ░░░░░░░░    ░░░░░
    """
    print0(banner)

def is_ddp():
    # TODO is there a proper way
    return int(os.environ.get('RANK', -1)) != -1

def get_dist_info():
    if is_ddp():
        assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        return True, ddp_rank, ddp_local_rank, ddp_world_size
    else:
        return False, 0, 0, 1

def compute_init():
    """Basic initialization that we keep doing over and over, so make common."""

    # CUDA is currently required
    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"

    # Reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    # Precision
    torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls

    # Distributed setup: Distributed Data Parallel (DDP), optional
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp:
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device) # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device("cuda")

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")

    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device

def compute_cleanup():
    """Companion function to compute_init, to clean things up before script exit"""
    if is_ddp():
        dist.destroy_process_group()

class DummyWandb:
    """Useful if we wish to not use wandb but have all the same signatures"""
    def __init__(self):
        pass
    def log(self, *args, **kwargs):
        pass
    def finish(self):
        pass
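
A minimal sketch of the intended call pattern, as I read the code (editorial, not part of the commit): a script calls compute_init() once at startup, uses print0/DummyWandb so that only rank 0 logs, and calls compute_cleanup() before exit. It runs single-GPU as plain python or distributed under torchrun.

from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb

ddp, rank, local_rank, world_size, device = compute_init()
wandb_run = DummyWandb()  # swap for a real wandb run if desired
print0(f"running on {world_size} process(es), device={device}")
for step in range(10):
    loss = 0.0  # ... forward/backward/optimizer step would go here ...
    wandb_run.log({"step": step, "loss": loss})
wandb_run.finish()
compute_cleanup()
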
nanochat/configurator.py ADDED
@@ -0,0 +1,56 @@
"""
Poor Man's Configurator. Probably a terrible idea. Example usage:
$ python train.py config/override_file.py --batch_size=32
this will first run config/override_file.py, then override batch_size to 32

The code in this file will be run as follows from e.g. train.py:
>>> exec(open('configurator.py').read())

So it's not a Python module, it's just shuttling this code away from train.py
The code in this script then overrides the globals()

I know people are not going to love this, I just really dislike configuration
complexity and having to prepend config. to every single variable. If someone
comes up with a better simple Python solution I am all ears.
"""

import os
import sys
from ast import literal_eval

def print0(s="",**kwargs):
    ddp_rank = int(os.environ.get('RANK', 0))
    if ddp_rank == 0:
        print(s, **kwargs)

for arg in sys.argv[1:]:
    if '=' not in arg:
        # assume it's the name of a config file
        assert not arg.startswith('--')
        config_file = arg
        print0(f"Overriding config with {config_file}:")
        with open(config_file) as f:
            print0(f.read())
        exec(open(config_file).read())
    else:
        # assume it's a --key=value argument
        assert arg.startswith('--')
        key, val = arg.split('=')
        key = key[2:]
        if key in globals():
            try:
                # attempt to eval it it (e.g. if bool, number, or etc)
                attempt = literal_eval(val)
            except (SyntaxError, ValueError):
                # if that goes wrong, just use the string
                attempt = val
            # ensure the types match ok
            if globals()[key] is not None:
                attempt_type = type(attempt)
                default_type = type(globals()[key])
                assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
            # cross fingers
            print0(f"Overriding: {key} = {attempt}")
            globals()[key] = attempt
        else:
            raise ValueError(f"Unknown config key: {key}")
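
To make the exec-based flow concrete, here is a hedged sketch of how a training script would consume it (the train.py shown is illustrative, not part of this commit): defaults are ordinary module-level globals, and the configurator overrides them from a config file and/or --key=value flags.

# train.py (illustrative) -- defaults are ordinary globals
batch_size = 64
learning_rate = 3e-4
run_name = "default"

# apply overrides; the path to configurator.py depends on where the script lives
exec(open('configurator.py').read())

print(batch_size, learning_rate, run_name)
# invoked e.g. as: python train.py --batch_size=32 --run_name=debug
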
nanochat/core_eval.py ADDED
@@ -0,0 +1,262 @@
"""
Functions for evaluating the CORE metric, as described in the DCLM paper.
https://arxiv.org/abs/2406.11794

TODOs:
- All tasks ~match except for squad. We get 31% reference is 37%. Figure out why.
"""
import random

from jinja2 import Template
import torch
import torch.distributed as dist

# -----------------------------------------------------------------------------
# Prompt rendering utilities

def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
    """Render complete prompts for a multiple choice question"""
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}

{% endfor -%}
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
    template = Template(template_str)
    fewshot_examples = fewshot_examples or []
    context = {
        'fewshot_examples': fewshot_examples,
        'continuation_delimiter': continuation_delimiter,
        'item': item
    }
    prompts = [template.render(choice=choice, **context) for choice in item['choices']]
    return prompts


def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
    """Render complete prompts for a schema question"""
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}

{% endfor -%}
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
    template = Template(template_str)
    fewshot_examples = fewshot_examples or []
    context = {
        'fewshot_examples': fewshot_examples,
        'continuation_delimiter': continuation_delimiter,
        'item': item
    }
    prompts = [template.render(context=context_option, **context)
               for context_option in item['context_options']]
    return prompts


def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
    """
    Render complete prompt for a language modeling task.
    Notice that we manually trim the context in the template,
    which in some datasets seems to have trailing whitespace (which we don't want).
    """
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}

{% endfor -%}
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
    template = Template(template_str)
    fewshot_examples = fewshot_examples or []
    context = {
        'fewshot_examples': fewshot_examples,
        'continuation_delimiter': continuation_delimiter,
        'item': item
    }
    # Return two prompts: without and with the continuation
    prompt_without = template.render(include_continuation=False, **context)
    prompt_with = template.render(include_continuation=True, **context)
    # Due to the way the data seems to be stored, I think I need to strip in the case of LM here.
    # Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next
    # token in prompt_with), meaning we don't get a nice and clean prefix in the token space
    # to detect the final continuation. Tokenizers...
    prompt_without = prompt_without.strip()
    return [prompt_without, prompt_with]


def find_common_length(token_sequences, direction='left'):
    """
    Find the length of the common prefix or suffix across token sequences
    - direction: 'left' for prefix, 'right' for suffix
    """
    min_len = min(len(seq) for seq in token_sequences)
    indices = {
        'left': range(min_len),
        'right': range(-1, -min_len-1, -1)
    }[direction]
    # Find the first position where the token sequences differ
    for i, idx in enumerate(indices):
        token = token_sequences[0][idx]
        if not all(seq[idx] == token for seq in token_sequences):
            return i
    return min_len


def stack_sequences(tokens, pad_token_id):
    """Stack up a list of token sequences, pad to longest on the right"""
    bsz, seq_len = len(tokens), max(len(x) for x in tokens)
    input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
    for i, x in enumerate(tokens):
        input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
    return input_ids


def batch_sequences_mc(tokenizer, prompts):
    # In multiple choice, contexts are the same but the continuation is different (common prefix)
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    # figure out the start and end of each continuation
    answer_start_idx = find_common_length(tokens, direction='left')
    start_indices = [answer_start_idx] * len(prompts)
    end_indices = [len(x) for x in tokens]
    return tokens, start_indices, end_indices


def batch_sequences_schema(tokenizer, prompts):
    # In schema tasks, contexts vary but continuation is the same (common suffix)
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    # figure out the start and end of each context
    suffix_length = find_common_length(tokens, direction='right')
    end_indices = [len(x) for x in tokens]
    start_indices = [ei - suffix_length for ei in end_indices]
    return tokens, start_indices, end_indices


def batch_sequences_lm(tokenizer, prompts):
    # In LM tasks, we have two prompts: without and with continuation
    tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
    tokens_without, tokens_with = tokens
    start_idx, end_idx = len(tokens_without), len(tokens_with)
    assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
    assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
    # we only need the with continuation prompt in the LM task, i.e. batch size of 1
    return [tokens_with], [start_idx], [end_idx]


@torch.no_grad()
def forward_model(model, input_ids):
    """
    Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
    The last column of losses is set to nan because we don't have autoregressive targets there.
    """
    batch_size, seq_len = input_ids.size()
    outputs = model(input_ids)
    # Roll the tensor to the left by one position to get the (autoregressive) target ids
    target_ids = torch.roll(input_ids, shifts=-1, dims=1)
    # Calculate cross entropy at all positions
    losses = torch.nn.functional.cross_entropy(
        outputs.view(batch_size * seq_len, -1),
        target_ids.view(batch_size * seq_len),
        reduction='none'
    ).view(batch_size, seq_len)
    # Set the last column to be nan because there is no autoregressive loss there
    losses[:, -1] = float('nan')
    # Get the argmax predictions at each position
    predictions = outputs.argmax(dim=-1)
    return losses, predictions


@torch.no_grad()
def evaluate_example(idx, model, tokenizer, data, device, task_meta):
    """Evaluate a single example, return True if correct, False otherwise"""
    item = data[idx]
    task_type = task_meta['task_type']
    num_fewshot = task_meta['num_fewshot']
    continuation_delimiter = task_meta['continuation_delimiter']

    # Sample few-shot examples (excluding current item)
    fewshot_examples = []
    if num_fewshot > 0:
        rng = random.Random(1234 + idx)
        available_indices = [i for i in range(len(data)) if i != idx]
        fewshot_indices = rng.sample(available_indices, num_fewshot)
        fewshot_examples = [data[i] for i in fewshot_indices]

    # Render prompts and batch sequences based on task type
    if task_type == 'multiple_choice':
        prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
    elif task_type == 'schema':
        prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
    elif task_type == 'language_modeling':
        prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
        tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
    else:
        raise ValueError(f"Unsupported task type: {task_type}")

    # Some models can't forward sequences beyond a certain length (e.g. GPT-2)
    # In these cases, we have to truncate sequences to max length and adjust the indices
    if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
        max_tokens = model.max_seq_len
        new_tokens, new_start_idxs, new_end_idxs = [], [], []
        for t, s, e in zip(tokens, start_idxs, end_idxs):
            if len(t) > max_tokens:
                num_to_crop = len(t) - max_tokens
                new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
                new_start_idxs.append(s - num_to_crop) # shift the indices down
                new_end_idxs.append(e - num_to_crop)
                assert s - num_to_crop >= 0, "this should never happen right?"
                assert e - num_to_crop >= 0, "this should never happen right?"
            else:
                new_tokens.append(t) # keep unchanged
                new_start_idxs.append(s)
                new_end_idxs.append(e)
        tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs

    # Stack up all the sequences into a batch
    pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
    input_ids = stack_sequences(tokens, pad_token_id)
    input_ids = input_ids.to(device)

    # Forward the model, get the autoregressive loss and argmax prediction at each token
    losses, predictions = forward_model(model, input_ids)

    # See if the losses/predictions come out correctly
    if task_type == 'language_modeling':
        # language modeling task is currently always batch size 1
        si = start_idxs[0]
        ei = end_idxs[0]
        # predictions[i] predict input_ids[i+1] autoregressively
        predicted_tokens = predictions[0, si-1:ei-1]
        actual_tokens = input_ids[0, si:ei]
        is_correct = torch.all(predicted_tokens == actual_tokens).item()
    elif task_type in ['multiple_choice', 'schema']:
        # For MC/schema: find the option with lowest average loss
        mean_losses = [losses[i, si-1:ei-1].mean().item()
                       for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
        pred_idx = mean_losses.index(min(mean_losses))
        is_correct = pred_idx == item['gold']
    else:
        raise ValueError(f"Unsupported task type: {task_type}")

    return is_correct


def evaluate_task(model, tokenizer, data, device, task_meta):
    """
    This function is responsible for evaluating one task across many examples.
    It also handles dispatch to all processes if the script is run with torchrun.
    """
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    correct = torch.zeros(len(data), dtype=torch.float32, device=device)
    # stride the examples to each rank
    for idx in range(rank, len(data), world_size):
        is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
        correct[idx] = float(is_correct)
    # sync results across all the processes if running distributed
    if world_size > 1:
        dist.barrier()
        dist.all_reduce(correct, op=dist.ReduceOp.SUM)
    # compute the mean
    mean_correct = correct.mean().item()
    return mean_correct
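
For orientation, a hedged sketch of the data and metadata shapes these functions expect, inferred from the rendering code above (the toy data list and its field names query/choices/gold mirror the multiple-choice template; the model/tokenizer are assumed to come from a trained base checkpoint):

from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task

ddp, rank, local_rank, world_size, device = compute_init()
model, tokenizer, meta = load_model("base", device, phase="eval")

# illustrative two-item multiple-choice "task"
data = [
    {"query": "The sky is", "choices": [" green", " blue"], "gold": 1},
    {"query": "Fire is", "choices": [" cold", " hot"], "gold": 1},
]
task_meta = {"task_type": "multiple_choice", "num_fewshot": 0, "continuation_delimiter": ""}
accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
print(f"accuracy: {accuracy:.3f}")
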
nanochat/dataloader.py ADDED
@@ -0,0 +1,49 @@
from collections import deque

import torch

from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import get_tokenizer

def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
    """Stream pretraining text from parquet files, tokenize, yield training batches."""
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
    # get the tokenizer and the bos token
    tokenizer = get_tokenizer()
    bos_token = tokenizer.get_bos_token_id()
    # scratch buffer holds the tokens for one iteration
    token_buffer = deque() # we stream tokens on the right and pop from the left
    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)

    # infinite iterator over document batches
    def document_batches():
        while True:
            # batch will iterate in group size of the parquet files, usually e.g. 1024 rows
            for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
                # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i+tokenizer_batch_size]
    batches = document_batches()

    batch_index = 0
    while True:
        # Accumulate enough tokens for one iteration before yielding.
        while len(token_buffer) < needed_tokens:
            doc_batch = next(batches)
            token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
            for tokens in token_lists:
                token_buffer.extend(tokens)
            batch_index += 1
        # Move tokens from the deque into the scratch buffer
        for i in range(needed_tokens):
            scratch[i] = token_buffer.popleft()
        # Create the inputs/targets as 1D tensors
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
        # Reshape to 2D and move to GPU async
        inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
        yield inputs, targets
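
A hedged usage sketch (editorial): the generator yields (inputs, targets) of shape (B, T) indefinitely, with targets shifted by one token. It assumes the parquet shards have already been downloaded, the tokenizer has been built, and a CUDA device is available.

from nanochat.dataloader import tokenizing_distributed_data_loader

B, T = 8, 1024
loader = tokenizing_distributed_data_loader(B, T, split="train")
inputs, targets = next(loader)
print(inputs.shape, inputs.dtype)    # torch.Size([8, 1024]) torch.int32
print(targets.shape, targets.dtype)  # torch.Size([8, 1024]) torch.int64
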
nanochat/dataset.py ADDED
@@ -0,0 +1,128 @@
"""
The base/pretraining dataset is a set of parquet files.
This file contains utilities for:
- iterating over the parquet files and yielding documents from it
- download the files on demand if they are not on disk

For details of how the dataset was prepared, see `repackage_data_reference.py`.
"""

import os
import argparse
import time
import requests
import pyarrow.parquet as pq
from multiprocessing import Pool

from nanochat.common import get_base_dir

# -----------------------------------------------------------------------------
# The specifics of the current pretraining dataset

# The URL on the internet where the data is hosted and downloaded from on demand
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet
index_to_filename = lambda index: f"shard_{index:05d}.parquet" # format of the filenames
base_dir = get_base_dir()
DATA_DIR = os.path.join(base_dir, "base_data")
os.makedirs(DATA_DIR, exist_ok=True)

# -----------------------------------------------------------------------------
# These functions are useful utilities to other modules, can/should be imported

def list_parquet_files(data_dir=None):
    """ Looks into a data dir and returns full paths to all parquet files. """
    data_dir = DATA_DIR if data_dir is None else data_dir
    parquet_files = sorted([
        f for f in os.listdir(data_dir)
        if f.endswith('.parquet') and not f.endswith('.tmp')
    ])
    parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
    return parquet_paths

def parquets_iter_batched(split, start=0, step=1):
    """
    Iterate through the dataset, in batches of underlying row_groups for efficiency.
    - split can be "train" or "val". the last parquet file will be val.
    - start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    parquet_paths = list_parquet_files()
    parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
    for filepath in parquet_paths:
        pf = pq.ParquetFile(filepath)
        for rg_idx in range(start, pf.num_row_groups, step):
            rg = pf.read_row_group(rg_idx)
            texts = rg.column('text').to_pylist()
            yield texts

# -----------------------------------------------------------------------------
def download_single_file(index):
    """ Downloads a single file index, with some backoff """

    # Construct the local filepath for this file and skip if it already exists
    filename = index_to_filename(index)
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        print(f"Skipping {filepath} (already exists)")
        return True

    # Construct the remote URL for this file
    url = f"{BASE_URL}/{filename}"
    print(f"Downloading {filename}...")

    # Download with retries
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            # Write to temporary file first
            temp_path = filepath + f".tmp"
            with open(temp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024): # 1MB chunks
                    if chunk:
                        f.write(chunk)
            # Move temp file to final location
            os.rename(temp_path, filepath)
            print(f"Successfully downloaded {filename}")
            return True

        except (requests.RequestException, IOError) as e:
            print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            # Clean up any partial files
            for path in [filepath + f".tmp", filepath]:
                if os.path.exists(path):
                    try:
                        os.remove(path)
                    except:
                        pass
            # Try a few times with exponential backoff: 2^attempt seconds
            if attempt < max_attempts:
                wait_time = 2 ** attempt
                print(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
            else:
                print(f"Failed to download {filename} after {max_attempts} attempts")
                return False

    return False


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
    parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
    parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
    args = parser.parse_args()

    num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
    ids_to_download = list(range(num))
    print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
    print(f"Target directory: {DATA_DIR}")
    print()
    with Pool(processes=args.num_workers) as pool:
        results = pool.map(download_single_file, ids_to_download)

    # Report results
    successful = sum(1 for success in results if success)
    print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
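
Hedged usage (editorial): the module doubles as a downloader script, so something like the shell command in the comment below fetches a few shards first (the exact invocation depends on how the package is installed), after which the iterator can be used directly.

# Download the first 2 shards (illustrative invocation):
#   python -m nanochat.dataset -n 2 -w 4
# Then iterate documents for the train split:
from nanochat.dataset import parquets_iter_batched

for texts in parquets_iter_batched(split="train", start=0, step=1):
    print(len(texts), texts[0][:80])
    break  # just peek at the first row group
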
nanochat/engine.py ADDED
@@ -0,0 +1,343 @@
"""
Engine for efficient inference of our models.

Everything works around token sequences:
- The user can send token sequences to the engine
- The engine returns the next token

Notes:
- The engine knows nothing about tokenization, it's purely token id sequences.

The whole thing is made as efficient as possible.
"""

import torch
import torch.nn.functional as F
import signal
import warnings
from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model

# -----------------------------------------------------------------------------
# Calculator tool helpers
@contextmanager
def timeout(duration, formula):
    def timeout_handler(signum, frame):
        raise Exception(f"'{formula}': timed out after {duration} seconds")

    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    yield
    signal.alarm(0)

def eval_with_timeout(formula, max_time=3):
    try:
        with timeout(max_time, formula):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", SyntaxWarning)
                return eval(formula)
    except Exception as e:
        signal.alarm(0)
        # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
        return None

def use_calculator(expr):
    """Evaluate a math expression safely."""
    expr = expr.replace(",", "")
    if any([x not in "0123456789*+-/.() " for x in expr]): # for now disallow non-numeric chars
        return None
    if "**" in expr: # for now disallow power operator, could be very expensive
        return None
    return eval_with_timeout(expr)

# -----------------------------------------------------------------------------
class KVCache:
    """
    Works hand-in-hand with the GPT model to maintain the KV cache.
    Note that the .pos advances automatically after the last layer of the Transformer inserts.
    """

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_cache = None
        self.pos = 0 # current position in time in the cache

    def reset(self):
        self.pos = 0

    def get_pos(self):
        return self.pos

    def prefill(self, other):
        """
        Prefill given another KV cache. Optionally expand along batch dim.
        This is used when we do batch 1 prefill and then want to generate
        multiple samples in parallel from there.
        """
        # 1) validate the shapes
        assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
        assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
        for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
            if ix in [0, 1, 3, 5]:
                # num_layers, batch_size, num_heads, head_dim must match
                assert dim1 == dim2, f"Batch dim mismatch: {dim1} != {dim2}"
            elif ix == 2:
                # batch_size can be expanded
                assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
            elif ix == 4:
                # seq_len: self must be longer than other
                assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
        # 2) initialize the cache
        dtype, device = other.kv_cache.dtype, other.kv_cache.device
        self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
        # 3) copy the data over
        self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
        # 4) update the pos
        self.pos = other.pos

    def insert_kv(self, layer_idx, k, v):
        # Lazy initialize the cache here because we need to know the dtype/device
        if self.kv_cache is None:
            self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
        # Insert new keys/values to the cache and return the full cache so far
        B, H, T_add, D = k.size()
        t0, t1 = self.pos, self.pos + T_add
        # Dynamically grow the cache if needed
        if t1 > self.kv_cache.size(4):
            t_needed = t1 + 1024 # as much as we need plus buffer of 1024
            t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
            current_shape = list(self.kv_cache.shape)
            current_shape[4] = t_needed
            self.kv_cache.resize_(current_shape)
        # Insert k, v into the cache
        self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
        self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
        # Return the full cached keys/values up to current position (as a view)
        key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
        value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
        # Increment pos after the last layer of the Transformer processes
        if layer_idx == self.kv_cache.size(0) - 1:
            self.pos = t1
        return key_view, value_view


# -----------------------------------------------------------------------------
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
    assert temperature >= 0.0, "temperature must be non-negative"
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is not None:
        k = min(top_k, logits.size(-1))
        vals, idx = torch.topk(logits, k, dim=-1)
        vals = vals / temperature
        probs = F.softmax(vals, dim=-1)
        choice = torch.multinomial(probs, num_samples=1, generator=rng)
        return idx.gather(1, choice)
    else:
        logits = logits / temperature
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=rng)

# -----------------------------------------------------------------------------

class RowState:
    # Per-row state tracking during generation
    def __init__(self, current_tokens=None):
        self.current_tokens = current_tokens or [] # Current token sequence for this row
        self.forced_tokens = deque() # Queue of tokens to force inject
        self.in_python_block = False # Whether we are inside a python block
        self.python_expr_tokens = [] # Tokens of the current python expression
        self.completed = False # Whether this row has completed generation

class Engine:

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer # needed for tool use

    @torch.inference_mode()
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
        """Same as generate, but does single prefill and then clones the KV cache."""
        assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
        device = self.model.get_device()
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)

        # Get the special tokens we need to coordinate the tool use state machine
        get_special = lambda s: self.tokenizer.encode_special(s)
        python_start = get_special("<|python_start|>")
        python_end = get_special("<|python_end|>")
        output_start = get_special("<|output_start|>")
        output_end = get_special("<|output_end|>")
        assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
        bos = self.tokenizer.get_bos_token_id() # if sampled, ends row

        # 1) Run a batch 1 prefill of the prompt tokens
        m = self.model.config
        kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
        kv_cache_prefill = KVCache(
            batch_size=1,
            seq_len=len(tokens),
            **kv_model_kwargs,
        )
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        logits = logits[:, -1, :]
        next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
        sampled_tokens = next_ids[:, 0].tolist()

        # 2) Replicate the KV cache for each sample/row
        kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
        kv_cache_decode = KVCache(
            batch_size=num_samples,
            seq_len=kv_length_hint,
            **kv_model_kwargs,
        )
        kv_cache_decode.prefill(kv_cache_prefill)
        del kv_cache_prefill # no need to keep this memory around

        # 3) Initialize states for each sample
        row_states = [RowState(tokens.copy()) for _ in range(num_samples)]

        # 4) Main generation loop
        num_generated = 0
        first_iteration = True
        while True:
            # Stop condition: we've reached max tokens
            if max_tokens is not None and num_generated >= max_tokens:
                break
            # Stop condition: all rows are completed
            if all(state.completed for state in row_states):
                break

            # Get sampled tokens - either from prefill or from forward pass
            if first_iteration:
                # Use the tokens we already sampled from prefill
                sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
                # TODO: we should sample a token for each row instead of broadcasting
                first_iteration = False
            else:
                # Forward the model and get the next token for each row
                logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
                logits = logits[:, -1, :] # (B, vocab_size) at last time step
                next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
                sampled_tokens = next_ids[:, 0].tolist()

            # Process each row: choose the next token, update state, optional tool use
            token_column = [] # contains the next token id along each row
            token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
            for i, state in enumerate(row_states):
                # Select the next token in this row
                is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
                token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
                next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
                token_column.append(next_token)
                # Update the state of this row to include the next token
                state.current_tokens.append(next_token)
                # On <|assistant_end|> or <|bos|>, mark the row as completed
                if next_token == assistant_end or next_token == bos:
                    state.completed = True
                # Handle tool logic
                if next_token == python_start:
                    state.in_python_block = True
                    state.python_expr_tokens = []
                elif next_token == python_end and state.in_python_block:
                    state.in_python_block = False
                    if state.python_expr_tokens:
                        expr = self.tokenizer.decode(state.python_expr_tokens)
                        result = use_calculator(expr)
                        if result is not None:
                            result_tokens = self.tokenizer.encode(str(result))
                            state.forced_tokens.append(output_start)
                            state.forced_tokens.extend(result_tokens)
                            state.forced_tokens.append(output_end)
                        state.python_expr_tokens = []
                elif state.in_python_block:
                    state.python_expr_tokens.append(next_token)

            # Yield the token column
            yield token_column, token_masks
            num_generated += 1
            # Prepare ids for next iteration
            ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)

    def generate_batch(self, tokens, num_samples=1, **kwargs):
        """
        Non-streaming batch generation that just returns the final token sequences.
        Returns a list of token sequences (list of lists of ints).
        Terminal tokens (assistant_end, bos) are not included in the results.
        """
        assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
        bos = self.tokenizer.get_bos_token_id()
        results = [tokens.copy() for _ in range(num_samples)]
        masks = [[0] * len(tokens) for _ in range(num_samples)]
        completed = [False] * num_samples
        for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
            for i, (token, mask) in enumerate(zip(token_column, token_masks)):
                if not completed[i]:
                    if token == assistant_end or token == bos:
                        completed[i] = True
                    else:
                        results[i].append(token)
                        masks[i].append(mask)
            # Stop if all rows are completed
            if all(completed):
                break
        return results, masks


if __name__ == "__main__":
    """
    Quick inline test to make sure that the naive/slow model.generate function
    is equivalent to the faster Engine.generate function here.
    """
    import time
    # init compute
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
    # load the model and tokenizer
    model, tokenizer, meta = load_model("base", device, phase="eval")
    bos_token_id = tokenizer.get_bos_token_id()
    # common hyperparameters
    kwargs = dict(max_tokens=64, temperature=0.0)
    # set the starting prompt
    prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
    # generate the reference sequence using the model.generate() function
    generated_tokens = []
    torch.cuda.synchronize()
    t0 = time.time()
    stream = model.generate(prompt_tokens, **kwargs)
    for token in stream:
        generated_tokens.append(token)
        chunk = tokenizer.decode([token])
        print(chunk, end="", flush=True)
    print()
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Reference time: {t1 - t0:.2f}s")
    reference_ids = generated_tokens
    # generate tokens with Engine
    generated_tokens = []
    engine = Engine(model, tokenizer)
    stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
    torch.cuda.synchronize()
    t0 = time.time()
    for token_column, token_masks in stream:
        token = token_column[0] # only print out the first row
        generated_tokens.append(token)
        chunk = tokenizer.decode([token])
        print(chunk, end="", flush=True)
    print()
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Engine time: {t1 - t0:.2f}s")
    # compare the two sequences
    for i in range(len(reference_ids)):
        if reference_ids[i] != generated_tokens[i]:
            print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
            break
    print(f"Match: {reference_ids == generated_tokens}")
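
A hedged sketch of the batch API, mirroring the inline test above (editorial, assumes a trained "base" checkpoint exists): generate_batch prefills the prompt once, clones the KV cache across num_samples rows, and returns the completed token sequences plus sampled/forced masks.

from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine

ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
model, tokenizer, meta = load_model("base", device, phase="eval")
engine = Engine(model, tokenizer)

prompt = tokenizer.encode("The chemical formula of water is", prepend=tokenizer.get_bos_token_id())
sequences, masks = engine.generate_batch(prompt, num_samples=2, max_tokens=32, temperature=0.8, top_k=50)
for seq in sequences:
    print(tokenizer.decode(seq))
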
nanochat/execution.py ADDED
@@ -0,0 +1,350 @@
1
+ """
2
+ Sandboxed execution utilities for running Python code that comes out of an LLM.
3
+ Adapted from OpenAI HumanEval code:
4
+ https://github.com/openai/human-eval/blob/master/human_eval/execution.py
5
+
6
+ What is covered:
7
+ - Each execution runs in its own process (can be killed if it hangs or crashes)
8
+ - Execution is limited by a timeout to stop infinite loops
9
+ - Memory limits are enforced by default (256MB)
10
+ - stdout and stderr are captured and returned
11
+ - Code runs in a temporary directory that is deleted afterwards
12
+ - Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen)
13
+
14
+ What is not covered:
15
+ - Not a true security sandbox
16
+ - Network access is not blocked (e.g. sockets could be opened)
17
+ - Python's dynamic features (e.g. ctypes) could bypass restrictions
18
+ - No kernel-level isolation (no seccomp, no containers, no virtualization)
19
+
20
+ Overall this sandbox is good for evaluation of generated code and protects against
21
+ accidental destructive behavior, but it is not safe against malicious adversarial code.
22
+ """
23
+
24
+ import contextlib
25
+ import faulthandler
26
+ import io
27
+ import multiprocessing
28
+ import os
29
+ import platform
30
+ import signal
31
+ import tempfile
32
+ from dataclasses import dataclass
33
+ from typing import Optional
34
+
35
+ # -----------------------------------------------------------------------------
36
+
37
+ @dataclass
38
+ class ExecutionResult:
39
+ """Result of executing Python code in a sandbox."""
40
+ success: bool
41
+ stdout: str
42
+ stderr: str
43
+ error: Optional[str] = None
44
+ timeout: bool = False
45
+ memory_exceeded: bool = False
46
+
47
+ def __repr__(self):
48
+ parts = []
49
+ parts.append(f"ExecutionResult(success={self.success}")
50
+ if self.timeout:
51
+ parts.append(", timeout=True")
52
+ if self.memory_exceeded:
53
+ parts.append(", memory_exceeded=True")
54
+ if self.error:
55
+ parts.append(f", error={self.error!r}")
56
+ if self.stdout:
57
+ parts.append(f", stdout={self.stdout!r}")
58
+ if self.stderr:
59
+ parts.append(f", stderr={self.stderr!r}")
60
+ parts.append(")")
61
+ return "".join(parts)
62
+
63
+
64
+ @contextlib.contextmanager
65
+ def time_limit(seconds: float):
66
+ def signal_handler(signum, frame):
67
+ raise TimeoutException("Timed out!")
68
+
69
+ signal.setitimer(signal.ITIMER_REAL, seconds)
70
+ signal.signal(signal.SIGALRM, signal_handler)
71
+ try:
72
+ yield
73
+ finally:
74
+ signal.setitimer(signal.ITIMER_REAL, 0)
75
+
76
+
77
+ @contextlib.contextmanager
78
+ def capture_io():
79
+ """Capture stdout and stderr, and disable stdin."""
80
+ stdout_capture = io.StringIO()
81
+ stderr_capture = io.StringIO()
82
+ stdin_block = WriteOnlyStringIO()
83
+ with contextlib.redirect_stdout(stdout_capture):
84
+ with contextlib.redirect_stderr(stderr_capture):
85
+ with redirect_stdin(stdin_block):
86
+ yield stdout_capture, stderr_capture
87
+
88
+
89
+ @contextlib.contextmanager
90
+ def create_tempdir():
91
+ with tempfile.TemporaryDirectory() as dirname:
92
+ with chdir(dirname):
93
+ yield dirname
94
+
95
+
96
+ class TimeoutException(Exception):
97
+ pass
98
+
99
+
100
+ class WriteOnlyStringIO(io.StringIO):
101
+ """StringIO that throws an exception when it's read from"""
102
+
103
+ def read(self, *args, **kwargs):
104
+ raise IOError
105
+
106
+ def readline(self, *args, **kwargs):
107
+ raise IOError
108
+
109
+ def readlines(self, *args, **kwargs):
110
+ raise IOError
111
+
112
+ def readable(self, *args, **kwargs):
113
+ """Returns True if the IO object can be read."""
114
+ return False
115
+
116
+
117
+ class redirect_stdin(contextlib._RedirectStream): # type: ignore
118
+ _stream = "stdin"
119
+
120
+
121
+ @contextlib.contextmanager
122
+ def chdir(root):
123
+ if root == ".":
124
+ yield
125
+ return
126
+ cwd = os.getcwd()
127
+ os.chdir(root)
128
+ try:
129
+ yield
130
+ except BaseException as exc:
131
+ raise exc
132
+ finally:
133
+ os.chdir(cwd)
134
+
135
+
136
+ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
137
+ """
138
+ This disables various destructive functions and prevents the generated code
139
+ from interfering with the test (e.g. fork bomb, killing other processes,
140
+ removing filesystem files, etc.)
141
+
142
+ WARNING
143
+ This function is NOT a security sandbox. Untrusted code, including model-
144
+ generated code, should not be blindly executed outside of one. See the
145
+ Codex paper for more information about OpenAI's code sandbox, and proceed
146
+ with caution.
147
+ """
148
+
149
+ if maximum_memory_bytes is not None:
150
+ import resource
151
+
152
+ resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
153
+ resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
154
+ if not platform.uname().system == "Darwin":
155
+ resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
156
+
157
+ faulthandler.disable()
158
+
159
+ import builtins
160
+
161
+ builtins.exit = None
162
+ builtins.quit = None
163
+
164
+ import os
165
+
166
+ os.environ["OMP_NUM_THREADS"] = "1"
167
+
168
+ os.kill = None
169
+ os.system = None
170
+ os.putenv = None
171
+ os.remove = None
172
+ os.removedirs = None
173
+ os.rmdir = None
174
+ os.fchdir = None
175
+ os.setuid = None
176
+ os.fork = None
177
+ os.forkpty = None
178
+ os.killpg = None
179
+ os.rename = None
180
+ os.renames = None
181
+ os.truncate = None
182
+ os.replace = None
183
+ os.unlink = None
184
+ os.fchmod = None
185
+ os.fchown = None
186
+ os.chmod = None
187
+ os.chown = None
188
+ os.chroot = None
189
+ os.fchdir = None
190
+ os.lchflags = None
191
+ os.lchmod = None
192
+ os.lchown = None
193
+ os.getcwd = None
194
+ os.chdir = None
195
+
196
+ import shutil
197
+
198
+ shutil.rmtree = None
199
+ shutil.move = None
200
+ shutil.chown = None
201
+
202
+ import subprocess
203
+
204
+ subprocess.Popen = None # type: ignore
205
+
206
+ __builtins__["help"] = None
207
+
208
+ import sys
209
+
210
+ sys.modules["ipdb"] = None
211
+ sys.modules["joblib"] = None
212
+ sys.modules["resource"] = None
213
+ sys.modules["psutil"] = None
214
+ sys.modules["tkinter"] = None
215
+
216
+
217
+ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
218
+ """Execute code in a subprocess with safety guards. Results are written to result_dict."""
219
+ with create_tempdir():
220
+
221
+ # These system calls are needed when cleaning up tempdir.
222
+ import os
223
+ import shutil
224
+
225
+ rmtree = shutil.rmtree
226
+ rmdir = os.rmdir
227
+ chdir = os.chdir
228
+
229
+ # Disable functionalities that can make destructive changes to the test.
230
+ reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
231
+
232
+ # Default to failure
233
+ result_dict.update({
234
+ "success": False,
235
+ "stdout": "",
236
+ "stderr": "",
237
+ "timeout": False,
238
+ "memory_exceeded": False,
239
+ "error": None,
240
+ })
241
+
242
+ try:
243
+ exec_globals = {}
244
+ with capture_io() as (stdout_capture, stderr_capture):
245
+ with time_limit(timeout):
246
+ # WARNING
247
+ # This program exists to execute untrusted model-generated code. Although
248
+ # it is highly unlikely that model-generated code will do something overtly
249
+ # malicious in response to this test suite, model-generated code may act
250
+ # destructively due to a lack of model capability or alignment.
251
+ # Users are strongly encouraged to sandbox this evaluation suite so that it
252
+ # does not perform destructive actions on their host or network. For more
253
+ # information on how OpenAI sandboxes its code, see the accompanying paper.
254
+ # Once you have read this disclaimer and taken appropriate precautions,
255
+ # note that the exec call below runs the untrusted code; proceed at your own risk:
256
+ exec(code, exec_globals)
257
+
258
+ result_dict.update({
259
+ "success": True,
260
+ "stdout": stdout_capture.getvalue(),
261
+ "stderr": stderr_capture.getvalue(),
262
+ })
263
+
264
+ except TimeoutException:
265
+ result_dict.update({
266
+ "timeout": True,
267
+ "error": "Execution timed out",
268
+ })
269
+
270
+ except MemoryError as e:
271
+ result_dict.update({
272
+ "memory_exceeded": True,
273
+ "error": f"Memory limit exceeded: {e}",
274
+ })
275
+
276
+ except BaseException as e:
277
+ result_dict.update({
278
+ "error": f"{type(e).__name__}: {e}",
279
+ })
280
+
281
+ # Needed for cleaning up.
282
+ shutil.rmtree = rmtree
283
+ os.rmdir = rmdir
284
+ os.chdir = chdir
285
+
286
+
287
+ def execute_code(
288
+ code: str,
289
+ timeout: float = 5.0, # 5 seconds default
290
+ maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
291
+ ) -> ExecutionResult:
292
+ """
293
+ Execute Python code in a sandboxed environment.
294
+
295
+ Args:
296
+ code: Python code to execute as a string
297
+ timeout: Maximum execution time in seconds (default: 5.0)
298
+ maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)
299
+
300
+ Returns:
301
+ ExecutionResult with success status, stdout/stderr, and error information
302
+
303
+ Example:
304
+ >>> result = execute_code("print('hello world')")
305
+ >>> result.success
306
+ True
307
+ >>> result.stdout
308
+ 'hello world\\n'
309
+ """
310
+
311
+ manager = multiprocessing.Manager()
312
+ result_dict = manager.dict()
313
+
314
+ p = multiprocessing.Process(
315
+ target=_unsafe_execute,
316
+ args=(code, timeout, maximum_memory_bytes, result_dict)
317
+ )
318
+ p.start()
319
+ p.join(timeout=timeout + 1)
320
+
321
+ if p.is_alive():
322
+ p.kill()
323
+ return ExecutionResult(
324
+ success=False,
325
+ stdout="",
326
+ stderr="",
327
+ error="Execution timed out (process killed)",
328
+ timeout=True,
329
+ memory_exceeded=False,
330
+ )
331
+
332
+ if not result_dict:
333
+ return ExecutionResult(
334
+ success=False,
335
+ stdout="",
336
+ stderr="",
337
+ error="Execution failed (no result returned)",
338
+ timeout=True,
339
+ memory_exceeded=False,
340
+ )
341
+
342
+ return ExecutionResult(
343
+ success=result_dict["success"],
344
+ stdout=result_dict["stdout"],
345
+ stderr=result_dict["stderr"],
346
+ error=result_dict["error"],
347
+ timeout=result_dict["timeout"],
348
+ memory_exceeded=result_dict["memory_exceeded"],
349
+ )
350
+
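A minimal usage sketch for the `execute_code` helper above (not part of the commit; made-up snippets, and the exact error string depends on which timeout path fires):

```python
# Hypothetical usage of nanochat/execution.py; assumes the repo root is on PYTHONPATH.
from nanochat.execution import execute_code

if __name__ == "__main__":  # guard needed on spawn-based platforms, since a child process is launched
    # A well-behaved snippet: captured stdout comes back on the result object.
    ok = execute_code("print(sum(range(10)))")
    print(ok.success, repr(ok.stdout))   # True '45\n'

    # An infinite loop is cut off by the timeout.
    hung = execute_code("while True: pass", timeout=1.0)
    print(hung.timeout, hung.error)      # True, "Execution timed out" (or "... (process killed)")
```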
nanochat/gpt.py ADDED
@@ -0,0 +1,322 @@
1
+ """
2
+ GPT model (rewrite, a lot simpler)
3
+ Notable features:
4
+ - rotary embeddings (and no positional embeddings)
5
+ - QK norm
6
+ - untied weights for token embedding and lm_head
7
+ - relu^2 activation in MLP
8
+ - norm after token embedding
9
+ - no learnable params in rmsnorm
10
+ - no bias in linear layers
11
+ - Multi-Query Attention (MQA) support for more efficient inference
12
+ """
13
+
14
+ import math
15
+ from functools import partial
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from nanochat.common import get_dist_info, print0
23
+ from nanochat.muon import Muon, DistMuon
24
+ from nanochat.adamw import DistAdamW
25
+
26
+ @dataclass
27
+ class GPTConfig:
28
+ sequence_len: int = 1024
29
+ vocab_size: int = 50304
30
+ n_layer: int = 12
31
+ n_head: int = 6 # number of query heads
32
+ n_kv_head: int = 6 # number of key/value heads (MQA)
33
+ n_embd: int = 768
34
+
35
+
36
+ def norm(x):
37
+ # Purely functional rmsnorm with no learnable params
38
+ return F.rms_norm(x, (x.size(-1),))
39
+
40
+
41
+ def apply_rotary_emb(x, cos, sin):
42
+ assert x.ndim == 4 # multihead attention
43
+ d = x.shape[3] // 2
44
+ x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
45
+ y1 = x1 * cos + x2 * sin # rotate pairs of dims
46
+ y2 = x1 * (-sin) + x2 * cos
47
+ out = torch.cat([y1, y2], 3) # re-assemble
48
+ out = out.to(x.dtype) # ensure input/output dtypes match
49
+ return out
50
+
51
+
52
+ def repeat_kv(x, n_rep):
53
+ """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
54
+ if n_rep == 1:
55
+ return x
56
+ bs, n_kv_heads, slen, head_dim = x.shape
57
+ return (
58
+ x[:, :, None, :, :]
59
+ .expand(bs, n_kv_heads, n_rep, slen, head_dim)
60
+ .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
61
+ )
62
+
63
+
64
+ class CausalSelfAttention(nn.Module):
65
+ def __init__(self, config, layer_idx):
66
+ super().__init__()
67
+ self.layer_idx = layer_idx
68
+ self.n_head = config.n_head
69
+ self.n_kv_head = config.n_kv_head
70
+ self.n_embd = config.n_embd
71
+ self.head_dim = self.n_embd // self.n_head
72
+ assert self.n_embd % self.n_head == 0
73
+ assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
74
+ self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
75
+ self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
76
+ self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
77
+ self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
78
+
79
+ def forward(self, x, cos_sin, kv_cache):
80
+ B, T, C = x.size()
81
+
82
+ # Project the input to get queries, keys, and values
83
+ q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
84
+ k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
85
+ v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
86
+
87
+ # Apply Rotary Embeddings to queries and keys to get relative positional encoding
88
+ cos, sin = cos_sin
89
+ q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
90
+ q, k = norm(q), norm(k) # QK norm
91
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
92
+
93
+ # Apply KV cache: insert current k,v into cache, get the full view so far
94
+ if kv_cache is not None:
95
+ k, v = kv_cache.insert_kv(self.layer_idx, k, v)
96
+ Tq = q.size(2) # number of queries in this forward pass
97
+ Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
98
+
99
+ # Apply MQA: replicate the key/value heads for each query head
100
+ nrep = self.n_head // self.n_kv_head
101
+ k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
102
+
103
+ # Attention: queries attend to keys/values autoregressively. A few cases to handle:
104
+ if kv_cache is None or Tq == Tk:
105
+ # During training (no KV cache), attend as usual with causal attention
106
+ # And even if there is KV cache, we can still use this simple version when Tq == Tk
107
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
108
+ elif Tq == 1:
109
+ # During inference but with a single query in this forward pass:
110
+ # The query has to attend to all the keys/values in the cache
111
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
112
+ else:
113
+ # During inference AND we have a chunk of queries in this forward pass:
114
+ # First, each query attends to all the cached keys/values (i.e. full prefix)
115
+ attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
116
+ prefix_len = Tk - Tq
117
+ if prefix_len > 0: # can't be negative but could be zero
118
+ attn_mask[:, :prefix_len] = True
119
+ # Then, causal attention within this chunk
120
+ attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
121
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
122
+
123
+ # Re-assemble the heads side by side and project back to residual stream
124
+ y = y.transpose(1, 2).contiguous().view(B, T, -1)
125
+ y = self.c_proj(y)
126
+ return y
127
+
128
+
129
+ class MLP(nn.Module):
130
+ def __init__(self, config):
131
+ super().__init__()
132
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
133
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
134
+
135
+ def forward(self, x):
136
+ x = self.c_fc(x)
137
+ x = F.relu(x).square()
138
+ x = self.c_proj(x)
139
+ return x
140
+
141
+
142
+ class Block(nn.Module):
143
+ def __init__(self, config, layer_idx):
144
+ super().__init__()
145
+ self.attn = CausalSelfAttention(config, layer_idx)
146
+ self.mlp = MLP(config)
147
+
148
+ def forward(self, x, cos_sin, kv_cache):
149
+ x = x + self.attn(norm(x), cos_sin, kv_cache)
150
+ x = x + self.mlp(norm(x))
151
+ return x
152
+
153
+
154
+ class GPT(nn.Module):
155
+ def __init__(self, config):
156
+ super().__init__()
157
+ self.config = config
158
+ self.transformer = nn.ModuleDict({
159
+ "wte": nn.Embedding(config.vocab_size, config.n_embd),
160
+ "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
161
+ })
162
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
163
+ # To support meta device initialization, we init the rotary embeddings here, but it's fake
164
+ # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
165
+ # so let's just over-compute them, but assert fail if we ever reach that amount.
166
+ # In the future we can dynamically grow the cache, for now it's fine.
167
+ self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
168
+ head_dim = config.n_embd // config.n_head
169
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
170
+ self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
171
+ self.register_buffer("sin", sin, persistent=False)
172
+ # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
173
+ self.transformer.wte.to(dtype=torch.bfloat16)
174
+
175
+ def init_weights(self):
176
+ self.apply(self._init_weights)
177
+ # zero out classifier weights
178
+ torch.nn.init.zeros_(self.lm_head.weight)
179
+ # zero out c_proj weights in all blocks
180
+ for block in self.transformer.h:
181
+ torch.nn.init.zeros_(block.mlp.c_proj.weight)
182
+ torch.nn.init.zeros_(block.attn.c_proj.weight)
183
+ # init the rotary embeddings
184
+ head_dim = self.config.n_embd // self.config.n_head
185
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
186
+ self.cos, self.sin = cos, sin
187
+
188
+ def _init_weights(self, module):
189
+ if isinstance(module, nn.Linear):
190
+ # https://arxiv.org/pdf/2310.17813
191
+ fan_out = module.weight.size(0)
192
+ fan_in = module.weight.size(1)
193
+ std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
194
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
195
+ if module.bias is not None:
196
+ torch.nn.init.zeros_(module.bias)
197
+ elif isinstance(module, nn.Embedding):
198
+ torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
199
+
200
+ # TODO: bump base theta more, e.g. 100K is more common more recently
201
+ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
202
+ # autodetect the device from model embeddings
203
+ if device is None:
204
+ device = self.transformer.wte.weight.device
205
+ # stride the channels
206
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
207
+ inv_freq = 1.0 / (base ** (channel_range / head_dim))
208
+ # stride the time steps
209
+ t = torch.arange(seq_len, dtype=torch.float32, device=device)
210
+ # calculate the rotation frequencies at each (time, channel) pair
211
+ freqs = torch.outer(t, inv_freq)
212
+ cos, sin = freqs.cos(), freqs.sin()
213
+ cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
214
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
215
+ return cos, sin
216
+
217
+ def get_device(self):
218
+ return self.transformer.wte.weight.device
219
+
220
+ def estimate_flops(self):
221
+ """ Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
222
+ nparams = sum(p.numel() for p in self.parameters())
223
+ nparams_embedding = self.transformer.wte.weight.numel()
224
+ l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
225
+ num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
226
+ return num_flops_per_token
227
+
228
+ def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
229
+ model_dim = self.config.n_embd
230
+ ddp, rank, local_rank, world_size = get_dist_info()
231
+ # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
232
+ matrix_params = list(self.transformer.h.parameters())
233
+ embedding_params = list(self.transformer.wte.parameters())
234
+ lm_head_params = list(self.lm_head.parameters())
235
+ assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
236
+ # Create the AdamW optimizer for the embedding and lm_head
237
+ # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
238
+ dmodel_lr_scale = (model_dim / 768) ** -0.5
239
+ if rank == 0:
240
+ print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
241
+ adam_groups = [
242
+ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
243
+ dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
244
+ ]
245
+ adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
246
+ AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
247
+ adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
248
+ # Create the Muon optimizer for the linear layers
249
+ muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
250
+ MuonFactory = DistMuon if ddp else Muon
251
+ muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
252
+ # Combine the two optimizers into one list
253
+ optimizers = [adamw_optimizer, muon_optimizer]
254
+ for opt in optimizers:
255
+ for group in opt.param_groups:
256
+ group["initial_lr"] = group["lr"]
257
+ return optimizers
258
+
259
+ def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
260
+ B, T = idx.size()
261
+
262
+ # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
263
+ assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
264
+ assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
265
+ assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
266
+ # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
267
+ T0 = 0 if kv_cache is None else kv_cache.get_pos()
268
+ cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
269
+
270
+ # Forward the trunk of the Transformer
271
+ x = self.transformer.wte(idx)
272
+ x = norm(x)
273
+ for block in self.transformer.h:
274
+ x = block(x, cos_sin, kv_cache)
275
+ x = norm(x)
276
+
277
+ # Forward the lm_head (compute logits)
278
+ softcap = 15
279
+ if targets is not None:
280
+ # training mode: compute and return the loss
281
+ # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
282
+ logits = self.lm_head(x)
283
+ logits = softcap * torch.tanh(logits / softcap) # logits softcap
284
+ logits = logits.float() # use tf32/fp32 for logits
285
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
286
+ return loss
287
+ else:
288
+ # inference mode: compute and return the logits
289
+ logits = self.lm_head(x)
290
+ logits = softcap * torch.tanh(logits / softcap) # logits softcap
291
+ return logits
292
+
293
+ @torch.inference_mode()
294
+ def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
295
+ """
296
+ Naive autoregressive streaming inference.
297
+ To make it super simple, let's assume:
298
+ - batch size is 1
299
+ - ids and the yielded tokens are simple Python lists and ints
300
+ """
301
+ assert isinstance(tokens, list)
302
+ device = self.get_device()
303
+ rng = None
304
+ if temperature > 0:
305
+ rng = torch.Generator(device=device)
306
+ rng.manual_seed(seed)
307
+ ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
308
+ for _ in range(max_tokens):
309
+ logits = self.forward(ids) # (B, T, vocab_size)
310
+ logits = logits[:, -1, :] # (B, vocab_size)
311
+ if top_k is not None:
312
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
313
+ logits[logits < v[:, [-1]]] = -float('Inf')
314
+ if temperature > 0:
315
+ logits = logits / temperature
316
+ probs = F.softmax(logits, dim=-1)
317
+ next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
318
+ else:
319
+ next_ids = torch.argmax(logits, dim=-1, keepdim=True)
320
+ ids = torch.cat((ids, next_ids), dim=1)
321
+ token = next_ids.item()
322
+ yield token
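To make the third attention case in `CausalSelfAttention.forward` concrete, here is a small stand-alone sketch (toy sizes, plain CPU tensors, not part of the commit) of the mask built when a chunk of `Tq` queries follows a cached prefix:

```python
import torch

# 3 new queries arrive on top of a cache holding 2 positions -> 5 keys total.
Tq, Tk = 3, 5
prefix_len = Tk - Tq  # = 2

attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool)  # True = keep, False = mask
attn_mask[:, :prefix_len] = True                     # every query sees the full cached prefix
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool))  # causal within the chunk

print(attn_mask.int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```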
nanochat/logo.svg ADDED
nanochat/loss_eval.py ADDED
@@ -0,0 +1,63 @@
1
+ """
2
+ A number of functions that help with evaluating a base model.
3
+ """
4
+ import math
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ @torch.no_grad()
9
+ def evaluate_bpb(model, batches, steps, token_bytes):
10
+ """
11
+ Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
12
+ which is a metric independent of the tokenizer's vocab size, meaning you are still comparing
13
+ apples:apples if you change the vocab size. The way this works is that instead of just
14
+ calculating the average loss as usual, you calculate the sum loss, and independently
15
+ also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
16
+ the number of bytes that the target tokens represent.
17
+
18
+ The added complexity is so that:
19
+ 1) All "normal" tokens are normalized by the length of the token in bytes
20
+ 2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
21
+ 3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.
22
+
23
+ In addition to evaluate_loss, we need the token_bytes tensor:
24
+ It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for
25
+ each token id, or 0 if the token is to not be counted (e.g. special tokens).
26
+ """
27
+ # record the losses
28
+ total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
29
+ total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
30
+ batch_iter = iter(batches)
31
+ for _ in range(steps):
32
+ x, y = next(batch_iter)
33
+ loss2d = model(x, y, loss_reduction='none') # (B, T)
34
+ loss2d = loss2d.view(-1) # flatten
35
+ y = y.view(-1) # flatten
36
+ if (y < 0).any():
37
+ # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
38
+ # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
39
+ valid = y >= 0
40
+ y_safe = torch.where(valid, y, torch.zeros_like(y))
41
+ # map valid targets to their byte length; ignored targets contribute 0 bytes
42
+ num_bytes2d = torch.where(
43
+ valid,
44
+ token_bytes[y_safe],
45
+ torch.zeros_like(y, dtype=token_bytes.dtype)
46
+ )
47
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
48
+ total_bytes += num_bytes2d.sum()
49
+ else:
50
+ # fast path: no ignored targets, safe to index directly
51
+ num_bytes2d = token_bytes[y]
52
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
53
+ total_bytes += num_bytes2d.sum()
54
+ # sum reduce across all ranks
55
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
56
+ if world_size > 1:
57
+ dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
58
+ dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
59
+ # move both to cpu, calculate bpb and return
60
+ total_nats = total_nats.item()
61
+ total_bytes = total_bytes.item()
62
+ bpb = total_nats / (math.log(2) * total_bytes)
63
+ return bpb
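A short worked example of the final bits-per-byte conversion above, with made-up totals just to show the arithmetic:

```python
import math

# Hypothetical sums accumulated over the eval batches.
total_nats = 2.5e6    # summed cross-entropy (nats) over all counted target tokens
total_bytes = 1.2e6   # summed byte lengths of those same target tokens

# nats -> bits (divide by ln 2), then normalize per byte.
bpb = total_nats / (math.log(2) * total_bytes)
print(f"{bpb:.3f} bits per byte")  # ~3.006
```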
nanochat/muon.py ADDED
@@ -0,0 +1,187 @@
1
+ """
2
+ Muon optimizer from Keller et al.
3
+ Also a lot of borrowing of ideas from modded-nanogpt.
4
+ """
5
+ import torch
6
+ from torch import Tensor
7
+ import torch.distributed as dist
8
+
9
+ @torch.compile
10
+ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
11
+ """
12
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
13
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
14
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
15
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
16
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
17
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
18
+ performance at all relative to UV^T, where USV^T = G is the SVD.
19
+ """
20
+ assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
21
+ a, b, c = (3.4445, -4.7750, 2.0315)
22
+ X = G.bfloat16()
23
+ if G.size(-2) > G.size(-1):
24
+ X = X.mT
25
+
26
+ # Ensure spectral norm is at most 1
27
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
28
+ # Perform the NS iterations
29
+ for _ in range(steps):
30
+ A = X @ X.mT
31
+ B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
32
+ X = a * X + B @ X
33
+
34
+ if G.size(-2) > G.size(-1):
35
+ X = X.mT
36
+ return X
37
+
38
+ class Muon(torch.optim.Optimizer):
39
+ """
40
+ Muon - MomentUm Orthogonalized by Newton-schulz
41
+
42
+ https://kellerjordan.github.io/posts/muon/
43
+
44
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
45
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
46
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
47
+ the advantage that it can be stably run in bfloat16 on the GPU.
48
+
49
+ Some warnings:
50
+ - This optimizer should not be used for the embedding layer, the final fully connected layer,
51
+ or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
52
+ - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
53
+
54
+ Arguments:
55
+ lr: The learning rate used by the internal SGD.
56
+ momentum: The momentum used by the internal SGD.
57
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
58
+ ns_steps: The number of Newton-Schulz iteration steps to use.
59
+ """
60
+ def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
61
+ defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
62
+ params: list[Tensor] = [*params]
63
+ param_groups = []
64
+ for size in {p.numel() for p in params}:
65
+ group = dict(params=[p for p in params if p.numel() == size])
66
+ param_groups.append(group)
67
+ super().__init__(param_groups, defaults)
68
+
69
+ @torch.no_grad()
70
+ def step(self):
71
+ for group in self.param_groups:
72
+ params: list[Tensor] = group["params"]
73
+ for p in params:
74
+ g = p.grad
75
+ assert g is not None
76
+ state = self.state[p]
77
+ if "momentum_buffer" not in state:
78
+ state["momentum_buffer"] = torch.zeros_like(g)
79
+ buf: Tensor = state["momentum_buffer"]
80
+ buf.lerp_(g, 1 - group["momentum"])
81
+ g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
82
+ g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
83
+ p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
84
+
85
+
86
+ class DistMuon(torch.optim.Optimizer):
87
+ """
88
+ Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz,
89
+ finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
90
+ - reduce_scatter(AVG) for gradient averaging
91
+ - all_gather to replicate updated weights
92
+
93
+ Notes:
94
+ * Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
95
+ params like embeddings or scalars.
96
+ * Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
97
+ by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
98
+ consolidate states beforehand.
99
+
100
+ Args:
101
+ params: iterable of Tensors
102
+ lr: learning rate
103
+ momentum: momentum coefficient in [0,1)
104
+ nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
105
+ ns_steps: number of Newton–Schulz iterations for the orthogonalization
106
+ """
107
+ def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
108
+ nesterov: bool = True, ns_steps: int = 5):
109
+ defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
110
+ params = list(params)
111
+ assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
112
+ rank = dist.get_rank()
113
+ # Group all parameters by their shape
114
+ shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
115
+ param_groups = []
116
+ for shape in shapes:
117
+ group_params = [p for p in params if p.shape == shape]
118
+ device, dtype = group_params[0].device, group_params[0].dtype
119
+ assert all(p.device == device for p in group_params)
120
+ assert all(p.dtype == dtype for p in group_params)
121
+ if rank == 0:
122
+ print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
123
+ param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
124
+ super().__init__(param_groups, defaults)
125
+
126
+ @torch.no_grad()
127
+ def step(self):
128
+ rank = dist.get_rank()
129
+ world_size = dist.get_world_size()
130
+
131
+ # Ensure all grads exist
132
+ assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
133
+
134
+ # Kick off all the reduce scatter operations to average up the gradients across all ranks
135
+ all_reduce_futures = []
136
+ for group in self.param_groups:
137
+ params = group["params"]
138
+ zero_buffer = group["zero_buffer"]
139
+ # Go through params in groups of world_size.
140
+ for base_i in range(0, len(params), world_size):
141
+ # The compute owner of each param is rank i % world_size
142
+ owner_idx = base_i + rank
143
+ # each rank stacks up its chunk of world_size params into a list
144
+ rs_input = [p.grad for p in params[base_i:base_i + world_size]]
145
+ # pad rs_input with the zero buffer to complete the group
146
+ rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
147
+ # the output buffer gets strided across the group based on the rank
148
+ rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
149
+ # reduce scatter the gradients within this group of world_size params
150
+ work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
151
+ all_reduce_futures.append(work)
152
+
153
+ # Now each rank computes the update and gathers
154
+ future_idx = 0
155
+ all_gather_futures = []
156
+ for group in self.param_groups:
157
+ params = group["params"]
158
+ zero_buffer = group["zero_buffer"]
159
+ # Go through params in groups of world_size.
160
+ for base_i in range(0, len(params), world_size):
161
+ # The compute owner of each param is rank i % world_size
162
+ owner_idx = base_i + rank # calculate the index of the param that this rank owns
163
+ # Wait for the reduce scatter to complete
164
+ all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
165
+ future_idx += 1
166
+ # Owner computes the Muon update, result is in its param
167
+ if owner_idx < len(params):
168
+ p = params[owner_idx]
169
+ g = p.grad # now averaged across ranks
170
+ state = self.state[p]
171
+ if "momentum_buffer" not in state:
172
+ state["momentum_buffer"] = torch.zeros_like(g)
173
+ buf: Tensor = state["momentum_buffer"]
174
+ buf.lerp_(g, 1.0 - group["momentum"])
175
+ g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
176
+ g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
177
+ scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
178
+ p.add_(g, alpha=-group["lr"] * scale)
179
+ # Replicate updated parameters to all ranks
180
+ ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
181
+ ag_output = params[base_i:base_i + world_size]
182
+ ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
183
+ work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
184
+ all_gather_futures.append(work)
185
+
186
+ # Wait for all work to finish
187
+ torch.futures.collect_all(all_gather_futures).wait()
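A quick sanity-check sketch for the Newton-Schulz orthogonalization above (not part of the commit; assumes a PyTorch build where `@torch.compile` works, and note the iteration runs in bfloat16, so the result is only approximately orthogonal):

```python
import torch
from nanochat.muon import zeropower_via_newtonschulz5

G = torch.randn(256, 512)
X = zeropower_via_newtonschulz5(G, steps=5).float()

# Per the docstring, singular values of X should land roughly in [0.5, 1.5],
# unlike those of the raw Gaussian matrix G.
print(torch.linalg.svdvals(G)[:3])
print(torch.linalg.svdvals(X)[:3])
```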
nanochat/report.py ADDED
@@ -0,0 +1,404 @@
1
+ """
2
+ Utilities for generating training report cards. More messy code than usual, will fix.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import shutil
8
+ import subprocess
9
+ import socket
10
+ import datetime
11
+ import platform
12
+ import psutil
13
+ import torch
14
+
15
+ def run_command(cmd):
16
+ """Run a shell command and return output, or None if it fails."""
17
+ try:
18
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
19
+ if result.returncode == 0:
20
+ return result.stdout.strip()
21
+ return None
22
+ except:
23
+ return None
24
+
25
+ def get_git_info():
26
+ """Get current git commit, branch, and dirty status."""
27
+ info = {}
28
+ info['commit'] = run_command("git rev-parse --short HEAD") or "unknown"
29
+ info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
30
+
31
+ # Check if repo is dirty (has uncommitted changes)
32
+ status = run_command("git status --porcelain")
33
+ info['dirty'] = bool(status) if status is not None else False
34
+
35
+ # Get commit message
36
+ info['message'] = run_command("git log -1 --pretty=%B") or ""
37
+ info['message'] = info['message'].split('\n')[0][:80] # First line, truncated
38
+
39
+ return info
40
+
41
+ def get_gpu_info():
42
+ """Get GPU information."""
43
+ if not torch.cuda.is_available():
44
+ return {"available": False}
45
+
46
+ num_devices = torch.cuda.device_count()
47
+ info = {
48
+ "available": True,
49
+ "count": num_devices,
50
+ "names": [],
51
+ "memory_gb": []
52
+ }
53
+
54
+ for i in range(num_devices):
55
+ props = torch.cuda.get_device_properties(i)
56
+ info["names"].append(props.name)
57
+ info["memory_gb"].append(props.total_memory / (1024**3))
58
+
59
+ # Get CUDA version
60
+ info["cuda_version"] = torch.version.cuda or "unknown"
61
+
62
+ return info
63
+
64
+ def get_system_info():
65
+ """Get system information."""
66
+ info = {}
67
+
68
+ # Basic system info
69
+ info['hostname'] = socket.gethostname()
70
+ info['platform'] = platform.system()
71
+ info['python_version'] = platform.python_version()
72
+ info['torch_version'] = torch.__version__
73
+
74
+ # CPU and memory
75
+ info['cpu_count'] = psutil.cpu_count(logical=False)
76
+ info['cpu_count_logical'] = psutil.cpu_count(logical=True)
77
+ info['memory_gb'] = psutil.virtual_memory().total / (1024**3)
78
+
79
+ # User and environment
80
+ info['user'] = os.environ.get('USER', 'unknown')
81
+ info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out')
82
+ info['working_dir'] = os.getcwd()
83
+
84
+ return info
85
+
86
+ def estimate_cost(gpu_info, runtime_hours=None):
87
+ """Estimate training cost based on GPU type and runtime."""
88
+
89
+ # Rough pricing, from Lambda Cloud
90
+ default_rate = 2.0
91
+ gpu_hourly_rates = {
92
+ "H100": 3.00,
93
+ "A100": 1.79,
94
+ "V100": 0.55,
95
+ }
96
+
97
+ if not gpu_info.get("available"):
98
+ return None
99
+
100
+ # Try to identify GPU type from name
101
+ hourly_rate = None
102
+ gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown"
103
+ for gpu_type, rate in gpu_hourly_rates.items():
104
+ if gpu_type in gpu_name:
105
+ hourly_rate = rate * gpu_info["count"]
106
+ break
107
+
108
+ if hourly_rate is None:
109
+ hourly_rate = default_rate * gpu_info["count"] # Default estimate
110
+
111
+ return {
112
+ "hourly_rate": hourly_rate,
113
+ "gpu_type": gpu_name,
114
+ "estimated_total": hourly_rate * runtime_hours if runtime_hours else None
115
+ }
116
+
117
+ def generate_header():
118
+ """Generate the header for a training report."""
119
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
120
+
121
+ git_info = get_git_info()
122
+ gpu_info = get_gpu_info()
123
+ sys_info = get_system_info()
124
+ cost_info = estimate_cost(gpu_info)
125
+
126
+ header = f"""# nanochat training report
127
+
128
+ Generated: {timestamp}
129
+
130
+ ## Environment
131
+
132
+ ### Git Information
133
+ - Branch: {git_info['branch']}
134
+ - Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"}
135
+ - Message: {git_info['message']}
136
+
137
+ ### Hardware
138
+ - Platform: {sys_info['platform']}
139
+ - CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical)
140
+ - Memory: {sys_info['memory_gb']:.1f} GB
141
+ """
142
+
143
+ if gpu_info.get("available"):
144
+ gpu_names = ", ".join(set(gpu_info["names"]))
145
+ total_vram = sum(gpu_info["memory_gb"])
146
+ header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
147
+ - GPU Memory: {total_vram:.1f} GB total
148
+ - CUDA Version: {gpu_info['cuda_version']}
149
+ """
150
+ else:
151
+ header += "- GPUs: None available\n"
152
+
153
+ if cost_info and cost_info["hourly_rate"] > 0:
154
+ header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n"""
155
+
156
+ header += f"""
157
+ ### Software
158
+ - Python: {sys_info['python_version']}
159
+ - PyTorch: {sys_info['torch_version']}
160
+
161
+ """
162
+
163
+ # bloat metrics: package all of the source code and assess its weight
164
+ packaged = run_command('files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml') or "" # guard: run_command returns None if the tool is unavailable
165
+ num_chars = len(packaged)
166
+ num_lines = len(packaged.split('\n'))
167
+ num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
168
+ num_tokens = num_chars // 4 # assume approximately 4 chars per token
169
+
170
+ # count dependencies via uv.lock
171
+ uv_lock_lines = 0
172
+ if os.path.exists('uv.lock'):
173
+ with open('uv.lock', 'r') as f:
174
+ uv_lock_lines = len(f.readlines())
175
+
176
+ header += f"""
177
+ ### Bloat
178
+ - Characters: {num_chars:,}
179
+ - Lines: {num_lines:,}
180
+ - Files: {num_files:,}
181
+ - Tokens (approx): {num_tokens:,}
182
+ - Dependencies (uv.lock lines): {uv_lock_lines:,}
183
+
184
+ """
185
+ return header
186
+
187
+ # -----------------------------------------------------------------------------
188
+
189
+ def slugify(text):
190
+ """Slugify a text string."""
191
+ return text.lower().replace(" ", "-")
192
+
193
+ # the expected files and their order
194
+ EXPECTED_FILES = [
195
+ "tokenizer-training.md",
196
+ "tokenizer-evaluation.md",
197
+ "base-model-training.md",
198
+ "base-model-loss.md",
199
+ "base-model-evaluation.md",
200
+ "midtraining.md",
201
+ "chat-evaluation-mid.md",
202
+ "chat-sft.md",
203
+ "chat-evaluation-sft.md",
204
+ "chat-rl.md",
205
+ "chat-evaluation-rl.md",
206
+ ]
207
+ # the metrics we're currently interested in
208
+ chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
209
+
210
+ def extract(section, keys):
211
+ """simple def to extract a single key from a section"""
212
+ if not isinstance(keys, list):
213
+ keys = [keys] # convenience
214
+ out = {}
215
+ for line in section.split("\n"):
216
+ for key in keys:
217
+ if key in line:
218
+ out[key] = line.split(":")[1].strip()
219
+ return out
220
+
221
+ def extract_timestamp(content, prefix):
222
+ """Extract timestamp from content with given prefix."""
223
+ for line in content.split('\n'):
224
+ if line.startswith(prefix):
225
+ time_str = line.split(":", 1)[1].strip()
226
+ try:
227
+ return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
228
+ except:
229
+ pass
230
+ return None
231
+
232
+ class Report:
233
+ """Maintains a bunch of logs, generates a final markdown report."""
234
+
235
+ def __init__(self, report_dir):
236
+ os.makedirs(report_dir, exist_ok=True)
237
+ self.report_dir = report_dir
238
+
239
+ def log(self, section, data):
240
+ """Log a section of data to the report."""
241
+ slug = slugify(section)
242
+ file_name = f"{slug}.md"
243
+ file_path = os.path.join(self.report_dir, file_name)
244
+ with open(file_path, "w") as f:
245
+ f.write(f"## {section}\n")
246
+ f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
247
+ for item in data:
248
+ if not item:
249
+ # skip falsy values like None or empty dict etc.
250
+ continue
251
+ if isinstance(item, str):
252
+ # directly write the string
253
+ f.write(item)
254
+ else:
255
+ # render a dict
256
+ for k, v in item.items():
257
+ if isinstance(v, float):
258
+ vstr = f"{v:.4f}"
259
+ elif isinstance(v, int) and v >= 10000:
260
+ vstr = f"{v:,.0f}"
261
+ else:
262
+ vstr = str(v)
263
+ f.write(f"- {k}: {vstr}\n")
264
+ f.write("\n")
265
+ return file_path
266
+
267
+ def generate(self):
268
+ """Generate the final report."""
269
+ report_dir = self.report_dir
270
+ report_file = os.path.join(report_dir, "report.md")
271
+ print(f"Generating report to {report_file}")
272
+ final_metrics = {} # the most important final metrics we'll add as table at the end
273
+ start_time = None
274
+ end_time = None
275
+ with open(report_file, "w") as out_file:
276
+ # write the header first
277
+ header_file = os.path.join(report_dir, "header.md")
278
+ if os.path.exists(header_file):
279
+ with open(header_file, "r") as f:
280
+ header_content = f.read()
281
+ out_file.write(header_content)
282
+ start_time = extract_timestamp(header_content, "Run started:")
283
+ # capture bloat data for summary later (the stuff after Bloat header and until \n\n)
284
+ bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
285
+ bloat_data = bloat_data.group(1) if bloat_data else ""
286
+ # process all the individual sections
287
+ for file_name in EXPECTED_FILES:
288
+ section_file = os.path.join(report_dir, file_name)
289
+ if not os.path.exists(section_file):
290
+ print(f"Warning: {section_file} does not exist, skipping")
291
+ continue
292
+ with open(section_file, "r") as in_file:
293
+ section = in_file.read()
294
+ # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
295
+ if "rl" not in file_name:
296
+ # Skip RL sections for end_time calculation because RL is experimental
297
+ end_time = extract_timestamp(section, "timestamp:")
298
+ # extract the most important metrics from the sections
299
+ if file_name == "base-model-evaluation.md":
300
+ final_metrics["base"] = extract(section, "CORE")
301
+ if file_name == "chat-evaluation-mid.md":
302
+ final_metrics["mid"] = extract(section, chat_metrics)
303
+ if file_name == "chat-evaluation-sft.md":
304
+ final_metrics["sft"] = extract(section, chat_metrics)
305
+ if file_name == "chat-evaluation-rl.md":
306
+ final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
307
+ # append this section of the report
308
+ out_file.write(section)
309
+ out_file.write("\n")
310
+ # add the final metrics table
311
+ out_file.write("## Summary\n\n")
312
+ # Copy over the bloat metrics from the header
313
+ out_file.write(bloat_data)
314
+ out_file.write("\n\n")
315
+ # Collect all unique metric names
316
+ all_metrics = set()
317
+ for stage_metrics in final_metrics.values():
318
+ all_metrics.update(stage_metrics.keys())
319
+ # Custom ordering: CORE first, ChatCORE last, rest in middle
320
+ all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
321
+ # Fixed column widths
322
+ stages = ["base", "mid", "sft", "rl"]
323
+ metric_width = 15
324
+ value_width = 8
325
+ # Write table header
326
+ header = f"| {'Metric'.ljust(metric_width)} |"
327
+ for stage in stages:
328
+ header += f" {stage.upper().ljust(value_width)} |"
329
+ out_file.write(header + "\n")
330
+ # Write separator
331
+ separator = f"|{'-' * (metric_width + 2)}|"
332
+ for stage in stages:
333
+ separator += f"{'-' * (value_width + 2)}|"
334
+ out_file.write(separator + "\n")
335
+ # Write table rows
336
+ for metric in all_metrics:
337
+ row = f"| {metric.ljust(metric_width)} |"
338
+ for stage in stages:
339
+ value = final_metrics.get(stage, {}).get(metric, "-")
340
+ row += f" {str(value).ljust(value_width)} |"
341
+ out_file.write(row + "\n")
342
+ out_file.write("\n")
343
+ # Calculate and write total wall clock time
344
+ if start_time and end_time:
345
+ duration = end_time - start_time
346
+ total_seconds = int(duration.total_seconds())
347
+ hours = total_seconds // 3600
348
+ minutes = (total_seconds % 3600) // 60
349
+ out_file.write(f"Total wall clock time: {hours}h{minutes}m\n")
350
+ else:
351
+ out_file.write("Total wall clock time: unknown\n")
352
+ # also cp the report.md file to current directory
353
+ print(f"Copying report.md to current directory for convenience")
354
+ shutil.copy(report_file, "report.md")
355
+ return report_file
356
+
357
+ def reset(self):
358
+ """Reset the report."""
359
+ # Remove section files
360
+ for file_name in EXPECTED_FILES:
361
+ file_path = os.path.join(self.report_dir, file_name)
362
+ if os.path.exists(file_path):
363
+ os.remove(file_path)
364
+ # Remove report.md if it exists
365
+ report_file = os.path.join(self.report_dir, "report.md")
366
+ if os.path.exists(report_file):
367
+ os.remove(report_file)
368
+ # Generate and write the header section with start timestamp
369
+ header_file = os.path.join(self.report_dir, "header.md")
370
+ header = generate_header()
371
+ start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
372
+ with open(header_file, "w") as f:
373
+ f.write(header)
374
+ f.write(f"Run started: {start_time}\n\n---\n\n")
375
+ print(f"Reset report and wrote header to {header_file}")
376
+
377
+ # -----------------------------------------------------------------------------
378
+ # nanochat-specific convenience functions
379
+
380
+ class DummyReport:
381
+ def log(self, *args, **kwargs):
382
+ pass
383
+ def reset(self, *args, **kwargs):
384
+ pass
385
+
386
+ def get_report():
387
+ # just for convenience, only rank 0 logs to report
388
+ from nanochat.common import get_base_dir, get_dist_info
389
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
390
+ if ddp_rank == 0:
391
+ report_dir = os.path.join(get_base_dir(), "report")
392
+ return Report(report_dir)
393
+ else:
394
+ return DummyReport()
395
+
396
+ if __name__ == "__main__":
397
+ import argparse
398
+ parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
399
+ parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
400
+ args = parser.parse_args()
401
+ if args.command == "generate":
402
+ get_report().generate()
403
+ elif args.command == "reset":
404
+ get_report().reset()
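For orientation, a minimal sketch of how the `Report` class above is driven (hypothetical directory and made-up metric values; `generate()` later stitches the per-section files into `report.md`):

```python
from nanochat.report import Report

report = Report("out/report")  # hypothetical report directory
report.log("Chat evaluation sft", [
    "Evaluated the SFT checkpoint on the chat metrics.\n\n",
    {"ARC-Easy": 0.6123, "GSM8K": 0.2741, "ChatCORE": 0.1987},  # made-up numbers
])
# -> writes out/report/chat-evaluation-sft.md, one of the EXPECTED_FILES that
#    Report.generate() concatenates and summarizes.
```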
nanochat/tokenizer.py ADDED
@@ -0,0 +1,395 @@
1
+ """
2
+ BPE Tokenizer in the style of GPT-4.
3
+
4
+ Two implementations are available:
5
+ 1) HuggingFace Tokenizer that can do both training and inference but is really confusing
6
+ 2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
7
+ """
8
+
9
+ import os
10
+ import copy
11
+ from functools import lru_cache
12
+
13
+ SPECIAL_TOKENS = [
14
+ # every document begins with the Beginning of Sequence (BOS) token that delimits documents
15
+ "<|bos|>",
16
+ # tokens below are only used during finetuning to render Conversations into token ids
17
+ "<|user_start|>", # user messages
18
+ "<|user_end|>",
19
+ "<|assistant_start|>", # assistant messages
20
+ "<|assistant_end|>",
21
+ "<|python_start|>", # assistant invokes python REPL tool
22
+ "<|python_end|>",
23
+ "<|output_start|>", # python REPL outputs back to assistant
24
+ "<|output_end|>",
25
+ ]
26
+
27
+ # NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
28
+ # I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
29
+ # I haven't validated that this is actually a good idea, TODO.
30
+ SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
31
+
32
+ # -----------------------------------------------------------------------------
33
+ # Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
34
+ from tokenizers import Tokenizer as HFTokenizer
35
+ from tokenizers import pre_tokenizers, decoders, Regex
36
+ from tokenizers.models import BPE
37
+ from tokenizers.trainers import BpeTrainer
38
+
39
+ class HuggingFaceTokenizer:
40
+ """Light wrapper around HuggingFace Tokenizer for some utilities"""
41
+
42
+ def __init__(self, tokenizer):
43
+ self.tokenizer = tokenizer
44
+
45
+ @classmethod
46
+ def from_pretrained(cls, hf_path):
47
+ # init from a HuggingFace pretrained tokenizer (e.g. "gpt2")
48
+ tokenizer = HFTokenizer.from_pretrained(hf_path)
49
+ return cls(tokenizer)
50
+
51
+ @classmethod
52
+ def from_directory(cls, tokenizer_dir):
53
+ # init from a local directory on disk (e.g. "out/tokenizer")
54
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
55
+ tokenizer = HFTokenizer.from_file(tokenizer_path)
56
+ return cls(tokenizer)
57
+
58
+ @classmethod
59
+ def train_from_iterator(cls, text_iterator, vocab_size):
60
+ # train from an iterator of text
61
+ # Configure the HuggingFace Tokenizer
62
+ tokenizer = HFTokenizer(BPE(
63
+ byte_fallback=True, # needed!
64
+ unk_token=None,
65
+ fuse_unk=False,
66
+ ))
67
+ # Normalizer: None
68
+ tokenizer.normalizer = None
69
+ # Pre-tokenizer: GPT-4 style
70
+ # the regex pattern used by GPT-4 to split text into groups before BPE
71
+ # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
72
+ # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
73
+ # (but I haven't validated this! TODO)
74
+ gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
75
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
76
+ pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
77
+ pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
78
+ ])
79
+ # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
80
+ tokenizer.decoder = decoders.ByteLevel()
81
+ # Post-processor: None
82
+ tokenizer.post_processor = None
83
+ # Trainer: BPE
84
+ trainer = BpeTrainer(
85
+ vocab_size=vocab_size,
86
+ show_progress=True,
87
+ min_frequency=0, # no minimum frequency
88
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
89
+ special_tokens=SPECIAL_TOKENS,
90
+ )
91
+ # Kick off the training
92
+ tokenizer.train_from_iterator(text_iterator, trainer)
93
+ return cls(tokenizer)
94
+
95
+ def get_vocab_size(self):
96
+ return self.tokenizer.get_vocab_size()
97
+
98
+ def get_special_tokens(self):
99
+ special_tokens_map = self.tokenizer.get_added_tokens_decoder()
100
+ special_tokens = [w.content for w in special_tokens_map.values()]
101
+ return special_tokens
102
+
103
+ def id_to_token(self, id):
104
+ return self.tokenizer.id_to_token(id)
105
+
106
+ def _encode_one(self, text, prepend=None, append=None):
107
+ # encode a single string
108
+ # prepend/append can be either a string of a special token or a token id directly.
109
+ assert isinstance(text, str)
110
+ ids = []
111
+ if prepend is not None:
112
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
113
+ ids.append(prepend_id)
114
+ ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
115
+ if append is not None:
116
+ append_id = append if isinstance(append, int) else self.encode_special(append)
117
+ ids.append(append_id)
118
+ return ids
119
+
120
+ def encode_special(self, text):
121
+ # encode a single special token via exact match
122
+ return self.tokenizer.token_to_id(text)
123
+
124
+ def get_bos_token_id(self):
125
+ bos = self.encode_special("<|bos|>")
126
+ return bos
127
+
128
+ def encode(self, text, *args, **kwargs):
129
+ if isinstance(text, str):
130
+ return self._encode_one(text, *args, **kwargs)
131
+ elif isinstance(text, list):
132
+ return [self._encode_one(t, *args, **kwargs) for t in text]
133
+ else:
134
+ raise ValueError(f"Invalid input type: {type(text)}")
135
+
136
+ def __call__(self, *args, **kwargs):
137
+ return self.encode(*args, **kwargs)
138
+
139
+ def decode(self, ids):
140
+ return self.tokenizer.decode(ids, skip_special_tokens=False)
141
+
142
+ def save(self, tokenizer_dir):
143
+ # save the tokenizer to disk
144
+ os.makedirs(tokenizer_dir, exist_ok=True)
145
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
146
+ self.tokenizer.save(tokenizer_path)
147
+ print(f"Saved tokenizer to {tokenizer_path}")
148
+
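The wrapper above can be exercised roughly as follows. This is an illustrative sketch only: it assumes the file is importable as nanochat.tokenizer (the file header is not in this hunk), that the `tokenizers` package is installed, and the toy corpus is made up.

# Illustrative sketch; assumes module path nanochat.tokenizer and the `tokenizers` package.
from nanochat.tokenizer import HuggingFaceTokenizer

corpus = ["hello world", "hello tokenizer", "byte-level bpe demo"]  # made-up toy corpus
tok = HuggingFaceTokenizer.train_from_iterator(iter(corpus), vocab_size=300)
ids = tok.encode("hello world", prepend="<|bos|>")  # prepend the BOS special token id
text = tok.decode(ids)                              # decodes special tokens too (skip_special_tokens=False)
tok.save("out/tokenizer")                           # writes out/tokenizer/tokenizer.json
tok2 = HuggingFaceTokenizer.from_directory("out/tokenizer")
assert tok2.get_vocab_size() == tok.get_vocab_size()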
149
+ # -----------------------------------------------------------------------------
150
+ # Tokenizer based on rustbpe + tiktoken combo
151
+ import pickle
152
+ import rustbpe
153
+ import tiktoken
154
+
155
+ class RustBPETokenizer:
156
+ """Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""
157
+
158
+ def __init__(self, enc, bos_token):
159
+ self.enc = enc
160
+ self.bos_token_id = self.encode_special(bos_token)
161
+
162
+ @classmethod
163
+ def train_from_iterator(cls, text_iterator, vocab_size):
164
+ # 1) train using rustbpe
165
+ tokenizer = rustbpe.Tokenizer()
166
+ # the special tokens are inserted later in __init__, we don't train them here
167
+ vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
168
+ assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
169
+ tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
170
+ # 2) construct the associated tiktoken encoding for inference
171
+ pattern = tokenizer.get_pattern()
172
+ mergeable_ranks_list = tokenizer.get_mergeable_ranks()
173
+ mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
174
+ tokens_offset = len(mergeable_ranks)
175
+ special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
176
+ enc = tiktoken.Encoding(
177
+ name="rustbpe",
178
+ pat_str=pattern,
179
+ mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
180
+ special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
181
+ )
182
+ return cls(enc, "<|bos|>")
183
+
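# Illustrative note: with vocab_size=V and K = len(SPECIAL_TOKENS), rustbpe trains
# V-K ordinary tokens (ids 0..V-K-1) and the special tokens are appended afterwards
# at ids V-K..V-1 via tokens_offset = len(mergeable_ranks). For example, with V=65536
# and an assumed K=9 special tokens, the special tokens would occupy ids 65527..65535.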
184
+ @classmethod
185
+ def from_directory(cls, tokenizer_dir):
186
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
187
+ with open(pickle_path, "rb") as f:
188
+ enc = pickle.load(f)
189
+ return cls(enc, "<|bos|>")
190
+
191
+ @classmethod
192
+ def from_pretrained(cls, tiktoken_name):
193
+ # https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py
194
+ enc = tiktoken.get_encoding(tiktoken_name)
195
+ # tiktoken calls the special document delimiter token "<|endoftext|>"
196
+ # yes this is confusing because this token is almost always PREPENDED to the beginning of the document
197
+ # it most often is used to signal the start of a new sequence to the LLM during inference etc.
198
+ # so in nanochat we always use "<|bos|>" (short for "beginning of sequence"), but historically it is often called "<|endoftext|>".
199
+ return cls(enc, "<|endoftext|>")
200
+
201
+ def get_vocab_size(self):
202
+ return self.enc.n_vocab
203
+
204
+ def get_special_tokens(self):
205
+ return self.enc.special_tokens_set
206
+
207
+ def id_to_token(self, id):
208
+ return self.enc.decode([id])
209
+
210
+ @lru_cache(maxsize=32)
211
+ def encode_special(self, text):
212
+ return self.enc.encode_single_token(text)
213
+
214
+ def get_bos_token_id(self):
215
+ return self.bos_token_id
216
+
217
+ def encode(self, text, prepend=None, append=None, num_threads=8):
218
+ # text can be either a string or a list of strings
219
+
220
+ if prepend is not None:
221
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
222
+ if append is not None:
223
+ append_id = append if isinstance(append, int) else self.encode_special(append)
224
+
225
+ if isinstance(text, str):
226
+ ids = self.enc.encode_ordinary(text)
227
+ if prepend is not None:
228
+ ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
229
+ if append is not None:
230
+ ids.append(append_id)
231
+ elif isinstance(text, list):
232
+ ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
233
+ if prepend is not None:
234
+ for ids_row in ids:
235
+ ids_row.insert(0, prepend_id) # TODO: same
236
+ if append is not None:
237
+ for ids_row in ids:
238
+ ids_row.append(append_id)
239
+ else:
240
+ raise ValueError(f"Invalid input type: {type(text)}")
241
+
242
+ return ids
243
+
244
+ def __call__(self, *args, **kwargs):
245
+ return self.encode(*args, **kwargs)
246
+
247
+ def decode(self, ids):
248
+ return self.enc.decode(ids)
249
+
250
+ def save(self, tokenizer_dir):
251
+ # save the encoding object to disk
252
+ os.makedirs(tokenizer_dir, exist_ok=True)
253
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
254
+ with open(pickle_path, "wb") as f:
255
+ pickle.dump(self.enc, f)
256
+ print(f"Saved tokenizer encoding to {pickle_path}")
257
+
258
+ def render_conversation(self, conversation, max_tokens=2048):
259
+ """
260
+ Tokenize a single Chat conversation (which we call a "doc" or "document" here).
261
+ Returns:
262
+ - ids: list[int] is a list of token ids of this rendered conversation
263
+ - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
264
+ """
265
+ # ids, masks that we will return and a helper function to help build them up.
266
+ ids, mask = [], []
267
+ def add_tokens(token_ids, mask_val):
268
+ if isinstance(token_ids, int):
269
+ token_ids = [token_ids]
270
+ ids.extend(token_ids)
271
+ mask.extend([mask_val] * len(token_ids))
272
+
273
+ # sometimes the first message is a system message...
274
+ # => just merge it with the second (user) message
275
+ if conversation["messages"][0]["role"] == "system":
276
+ # some conversation surgery is necessary here for now...
277
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
278
+ messages = conversation["messages"]
279
+ assert messages[1]["role"] == "user", "System message must be followed by a user message"
280
+ messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
281
+ messages = messages[1:]
282
+ else:
283
+ messages = conversation["messages"]
284
+ assert len(messages) >= 1, f"Conversation must have at least 1 message: {messages}"
285
+
286
+ # fetch all the special tokens we need
287
+ bos = self.get_bos_token_id()
288
+ user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
289
+ assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
290
+ python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
291
+ output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")
292
+
293
+ # now we can tokenize the conversation
294
+ add_tokens(bos, 0)
295
+ for i, message in enumerate(messages):
296
+
297
+ # some sanity checking here around assumptions, to prevent footguns
298
+ must_be_from = "user" if i % 2 == 0 else "assistant"
299
+ assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"
300
+
301
+ # content can be either a simple string or a list of parts (e.g. containing tool calls)
302
+ content = message["content"]
303
+
304
+ if message["role"] == "user":
305
+ assert isinstance(content, str), "User messages are simply expected to be strings"
306
+ value_ids = self.encode(content)
307
+ add_tokens(user_start, 0)
308
+ add_tokens(value_ids, 0)
309
+ add_tokens(user_end, 0)
310
+ elif message["role"] == "assistant":
311
+ add_tokens(assistant_start, 0)
312
+ if isinstance(content, str):
313
+ # simple string => simply add the tokens
314
+ value_ids = self.encode(content)
315
+ add_tokens(value_ids, 1)
316
+ elif isinstance(content, list):
317
+ for part in content:
318
+ value_ids = self.encode(part["text"])
319
+ if part["type"] == "text":
320
+ # string part => simply add the tokens
321
+ add_tokens(value_ids, 1)
322
+ elif part["type"] == "python":
323
+ # python tool call => add the tokens inside <|python_start|> and <|python_end|>
324
+ add_tokens(python_start, 1)
325
+ add_tokens(value_ids, 1)
326
+ add_tokens(python_end, 1)
327
+ elif part["type"] == "python_output":
328
+ # python output => add the tokens inside <|output_start|> and <|output_end|>
329
+ # none of these tokens are supervised because the tokens come from Python at test time
330
+ add_tokens(output_start, 0)
331
+ add_tokens(value_ids, 0)
332
+ add_tokens(output_end, 0)
333
+ else:
334
+ raise ValueError(f"Unknown part type: {part['type']}")
335
+ else:
336
+ raise ValueError(f"Unknown content type: {type(content)}")
337
+ add_tokens(assistant_end, 1)
338
+
339
+ # truncate to max_tokens tokens MAX (helps prevent OOMs)
340
+ ids = ids[:max_tokens]
341
+ mask = mask[:max_tokens]
342
+ return ids, mask
343
+
344
+ def visualize_tokenization(self, ids, mask):
345
+ """Small helper function useful in debugging: visualize the tokenization of render_conversation"""
346
+ RED = '\033[91m'
347
+ GREEN = '\033[92m'
348
+ RESET = '\033[0m'
349
+ tokens = []
350
+ for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
351
+ token_str = self.decode([token_id])
352
+ color = GREEN if mask_val == 1 else RED
353
+ tokens.append(f"{color}{token_str}{RESET}")
354
+ return '|'.join(tokens)
355
+
356
+ def render_for_completion(self, conversation):
357
+ """
358
+ Used during Reinforcement Learning. In that setting, we want to
359
+ render the conversation priming the Assistant for a completion.
360
+ Unlike the Chat SFT case, we don't need to return the mask.
361
+ """
362
+ # We have some surgery to do: we need to pop the last message (of the Assistant)
363
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
364
+ messages = conversation["messages"]
365
+ assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
366
+ messages.pop() # remove the last message (of the Assistant) inplace
367
+
368
+ # Now tokenize the conversation
369
+ ids, mask = self.render_conversation(conversation)
370
+
371
+ # Finally, to prime the Assistant for a completion, append the Assistant start token
372
+ assistant_start = self.encode_special("<|assistant_start|>")
373
+ ids.append(assistant_start)
374
+ return ids
375
+
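An end-to-end sketch of how the class above is meant to be used, illustrative only: the module path nanochat.tokenizer is assumed, rustbpe and tiktoken must be installed, and the corpus and conversation below are made up.

# Sketch only; assumes module path nanochat.tokenizer and installed rustbpe + tiktoken.
from nanochat.tokenizer import RustBPETokenizer

corpus = ["hello world, this is a tiny made-up corpus"] * 100
tok = RustBPETokenizer.train_from_iterator(iter(corpus), vocab_size=512)

# plain text, optionally framed by special tokens
ids = tok.encode("hello world", prepend="<|bos|>", append="<|assistant_end|>")

# chat rendering: mask == 1 marks the tokens the Assistant is trained on
conversation = {
    "messages": [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "4"},
    ]
}
ids, mask = tok.render_conversation(conversation)
print(tok.visualize_tokenization(ids, mask))  # green = supervised, red = unsupervised

# RL-style priming: last Assistant message is dropped and <|assistant_start|> is appended
prompt_ids = tok.render_for_completion(conversation)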
376
+ # -----------------------------------------------------------------------------
377
+ # nanochat-specific convenience functions
378
+
379
+ def get_tokenizer():
380
+ from nanochat.common import get_base_dir
381
+ base_dir = get_base_dir()
382
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
383
+ # return HuggingFaceTokenizer.from_directory(tokenizer_dir)
384
+ return RustBPETokenizer.from_directory(tokenizer_dir)
385
+
386
+ def get_token_bytes(device="cpu"):
387
+ import torch
388
+ from nanochat.common import get_base_dir
389
+ base_dir = get_base_dir()
390
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
391
+ token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
392
+ assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
393
+ with open(token_bytes_path, "rb") as f:
394
+ token_bytes = torch.load(f, map_location=device)
395
+ return token_bytes
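The two convenience functions assume the tokenizer artifacts were already written to <base_dir>/tokenizer by tok_train.py; a minimal sketch of their intended use (the import path is an assumption):

# Sketch; assumes tok_train.py has already populated <base_dir>/tokenizer.
from nanochat.tokenizer import get_tokenizer, get_token_bytes  # assumed import path

tok = get_tokenizer()                        # RustBPETokenizer loaded from tokenizer.pkl
ids = tok("Hello there", prepend="<|bos|>")  # __call__ forwards to encode()
token_bytes = get_token_bytes(device="cpu")  # tensor saved by tok_train.py
print(tok.get_vocab_size(), tuple(token_bytes.shape))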
nanochat/ui.html ADDED
@@ -0,0 +1,394 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>NanoChat</title>
7
+ <link rel="icon" type="image/svg+xml" href="/logo.svg">
8
+ <style>
9
+ :root {
10
+ color-scheme: light;
11
+ }
12
+
13
+ * {
14
+ box-sizing: border-box;
15
+ }
16
+
17
+ body {
18
+ font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
19
+ background-color: #ffffff;
20
+ color: #111827;
21
+ min-height: 100vh;
22
+ margin: 0;
23
+ display: flex;
24
+ flex-direction: column;
25
+ }
26
+
27
+ .header {
28
+ background-color: #ffffff;
29
+ padding: 1.25rem 1.5rem;
30
+ }
31
+
32
+ .header-left {
33
+ display: flex;
34
+ align-items: center;
35
+ gap: 0.75rem;
36
+ }
37
+
38
+ .header-logo {
39
+ height: 32px;
40
+ width: auto;
41
+ }
42
+
43
+ .header h1 {
44
+ font-size: 1.25rem;
45
+ font-weight: 600;
46
+ margin: 0;
47
+ color: #111827;
48
+ }
49
+
50
+ .new-conversation-btn {
51
+ width: 32px;
52
+ height: 32px;
53
+ padding: 0;
54
+ border: 1px solid #e5e7eb;
55
+ border-radius: 0.5rem;
56
+ background-color: #ffffff;
57
+ color: #6b7280;
58
+ cursor: pointer;
59
+ display: flex;
60
+ align-items: center;
61
+ justify-content: center;
62
+ transition: all 0.2s ease;
63
+ }
64
+
65
+ .new-conversation-btn:hover {
66
+ background-color: #f3f4f6;
67
+ border-color: #d1d5db;
68
+ color: #374151;
69
+ }
70
+
71
+ .chat-container {
72
+ flex: 1;
73
+ overflow-y: auto;
74
+ background-color: #ffffff;
75
+ }
76
+
77
+ .chat-wrapper {
78
+ max-width: 48rem;
79
+ margin: 0 auto;
80
+ padding: 2rem 1.5rem 3rem;
81
+ display: flex;
82
+ flex-direction: column;
83
+ gap: 0.75rem;
84
+ }
85
+
86
+ .message {
87
+ display: flex;
88
+ justify-content: flex-start;
89
+ margin-bottom: 0.5rem;
90
+ color: #0d0d0d;
91
+ }
92
+
93
+ .message.assistant {
94
+ justify-content: flex-start;
95
+ }
96
+
97
+ .message.user {
98
+ justify-content: flex-end;
99
+ }
100
+
101
+ .message-content {
102
+ white-space: pre-wrap;
103
+ line-height: 1.6;
104
+ max-width: 100%;
105
+ }
106
+
107
+ .message.assistant .message-content {
108
+ background: transparent;
109
+ border: none;
110
+ padding: 0.25rem 0;
111
+ }
112
+
113
+ .message.user .message-content {
114
+ background-color: #f3f4f6;
115
+ border-radius: 1.25rem;
116
+ padding: 0.8rem 1rem;
117
+ max-width: 65%;
118
+ }
119
+
120
+ .input-container {
121
+ background-color: #ffffff;
122
+ padding: 1rem;
123
+ }
124
+
125
+ .input-wrapper {
126
+ max-width: 48rem;
127
+ margin: 0 auto;
128
+ display: flex;
129
+ gap: 0.75rem;
130
+ align-items: flex-end;
131
+ }
132
+
133
+ .chat-input {
134
+ flex: 1;
135
+ padding: 0.8rem 1rem;
136
+ border: 1px solid #d1d5db;
137
+ border-radius: 0.75rem;
138
+ background-color: #ffffff;
139
+ color: #111827;
140
+ font-size: 1rem;
141
+ line-height: 1.5;
142
+ resize: none;
143
+ outline: none;
144
+ min-height: 54px;
145
+ max-height: 200px;
146
+ transition: border-color 0.2s ease, box-shadow 0.2s ease;
147
+ }
148
+
149
+ .chat-input::placeholder {
150
+ color: #9ca3af;
151
+ }
152
+
153
+ .chat-input:focus {
154
+ border-color: #2563eb;
155
+ box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
156
+ }
157
+
158
+ .send-button {
159
+ flex-shrink: 0;
160
+ padding: 0;
161
+ width: 54px;
162
+ height: 54px;
163
+ border: 1px solid #111827;
164
+ border-radius: 0.75rem;
165
+ background-color: #111827;
166
+ color: #ffffff;
167
+ display: flex;
168
+ align-items: center;
169
+ justify-content: center;
170
+ cursor: pointer;
171
+ transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
172
+ }
173
+
174
+ .send-button:hover:not(:disabled) {
175
+ background-color: #2563eb;
176
+ border-color: #2563eb;
177
+ }
178
+
179
+ .send-button:disabled {
180
+ cursor: not-allowed;
181
+ border-color: #d1d5db;
182
+ background-color: #e5e7eb;
183
+ color: #9ca3af;
184
+ }
185
+
186
+ .typing-indicator {
187
+ display: inline-block;
188
+ color: #6b7280;
189
+ letter-spacing: 0.15em;
190
+ }
191
+
192
+ .typing-indicator::after {
193
+ content: '···';
194
+ animation: typing 1.4s infinite;
195
+ }
196
+
197
+ @keyframes typing {
198
+ 0%, 60%, 100% { opacity: 0.2; }
199
+ 30% { opacity: 1; }
200
+ }
201
+
202
+ .error-message {
203
+ background-color: #fee2e2;
204
+ border: 1px solid #fecaca;
205
+ color: #b91c1c;
206
+ padding: 0.75rem 1rem;
207
+ border-radius: 0.75rem;
208
+ margin-top: 0.5rem;
209
+ }
210
+ </style>
211
+ </head>
212
+ <body>
213
+ <div class="header">
214
+ <div class="header-left">
215
+ <button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
216
+ <svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
217
+ <path d="M12 5v14"></path>
218
+ <path d="M5 12h14"></path>
219
+ </svg>
220
+ </button>
221
+ <h1>nanochat</h1>
222
+ </div>
223
+ </div>
224
+
225
+ <div class="chat-container" id="chatContainer">
226
+ <div class="chat-wrapper" id="chatWrapper">
227
+ <!-- Messages will be added here -->
228
+ </div>
229
+ </div>
230
+
231
+ <div class="input-container">
232
+ <div class="input-wrapper">
233
+ <textarea
234
+ id="chatInput"
235
+ class="chat-input"
236
+ placeholder="Ask anything"
237
+ rows="1"
238
+ onkeydown="handleKeyDown(event)"
239
+ ></textarea>
240
+ <button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
241
+ <svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
242
+ <path d="M22 2L11 13"></path>
243
+ <path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
244
+ </svg>
245
+ </button>
246
+ </div>
247
+ </div>
248
+
249
+ <script>
250
+ const API_URL = '';
251
+ const chatContainer = document.getElementById('chatContainer');
252
+ const chatWrapper = document.getElementById('chatWrapper');
253
+ const chatInput = document.getElementById('chatInput');
254
+ const sendButton = document.getElementById('sendButton');
255
+
256
+ let messages = [];
257
+ let isGenerating = false;
258
+
259
+ chatInput.addEventListener('input', function() {
260
+ this.style.height = 'auto';
261
+ this.style.height = Math.min(this.scrollHeight, 200) + 'px';
262
+ sendButton.disabled = !this.value.trim() || isGenerating;
263
+ });
264
+
265
+ function handleKeyDown(event) {
266
+ if (event.key === 'Enter' && !event.shiftKey) {
267
+ event.preventDefault();
268
+ sendMessage();
269
+ }
270
+ }
271
+
272
+ document.addEventListener('keydown', function(event) {
273
+ // Ctrl+Shift+N for new conversation
274
+ if (event.ctrlKey && event.shiftKey && event.key === 'N') {
275
+ event.preventDefault();
276
+ if (!isGenerating) {
277
+ newConversation();
278
+ }
279
+ }
280
+ });
281
+
282
+ function newConversation() {
283
+ messages = [];
284
+ chatWrapper.innerHTML = '';
285
+ chatInput.value = '';
286
+ chatInput.style.height = 'auto';
287
+ sendButton.disabled = false;
288
+ isGenerating = false;
289
+ chatInput.focus();
290
+ }
291
+
292
+ function addMessage(role, content) {
293
+ const messageDiv = document.createElement('div');
294
+ messageDiv.className = `message ${role}`;
295
+
296
+ const contentDiv = document.createElement('div');
297
+ contentDiv.className = 'message-content';
298
+ contentDiv.textContent = content;
299
+
300
+ messageDiv.appendChild(contentDiv);
301
+ chatWrapper.appendChild(messageDiv);
302
+
303
+ chatContainer.scrollTop = chatContainer.scrollHeight;
304
+ return contentDiv;
305
+ }
306
+
307
+ async function sendMessage() {
308
+ const message = chatInput.value.trim();
309
+ if (!message || isGenerating) return;
310
+
311
+ isGenerating = true;
312
+ chatInput.value = '';
313
+ chatInput.style.height = 'auto';
314
+ sendButton.disabled = true;
315
+
316
+ messages.push({ role: 'user', content: message });
317
+ addMessage('user', message);
318
+
319
+ const assistantContent = addMessage('assistant', '');
320
+ assistantContent.innerHTML = '<span class="typing-indicator"></span>';
321
+
322
+ try {
323
+ const response = await fetch(`${API_URL}/chat/completions`, {
324
+ method: 'POST',
325
+ headers: {
326
+ 'Content-Type': 'application/json',
327
+ },
328
+ body: JSON.stringify({
329
+ messages: messages,
330
+ stream: true,
331
+ temperature: 0.8,
332
+ max_tokens: 512
333
+ }),
334
+ });
335
+
336
+ if (!response.ok) {
337
+ throw new Error(`HTTP error! status: ${response.status}`);
338
+ }
339
+
340
+ const reader = response.body.getReader();
341
+ const decoder = new TextDecoder();
342
+ let fullResponse = '';
343
+ assistantContent.textContent = '';
344
+
345
+ while (true) {
346
+ const { done, value } = await reader.read();
347
+ if (done) break;
348
+
349
+ const chunk = decoder.decode(value);
350
+ const lines = chunk.split('\n');
351
+
352
+ for (const line of lines) {
353
+ if (line.startsWith('data: ')) {
354
+ try {
355
+ const data = JSON.parse(line.slice(6));
356
+ if (data.token) {
357
+ fullResponse += data.token;
358
+ assistantContent.textContent = fullResponse;
359
+ chatContainer.scrollTop = chatContainer.scrollHeight;
360
+ }
361
+ } catch (e) {
362
+ }
363
+ }
364
+ }
365
+ }
366
+
367
+ messages.push({ role: 'assistant', content: fullResponse });
368
+
369
+ } catch (error) {
370
+ console.error('Error:', error);
371
+ assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
372
+ } finally {
373
+ isGenerating = false;
374
+ sendButton.disabled = !chatInput.value.trim();
375
+ }
376
+ }
377
+
378
+ sendButton.disabled = false;
379
+
380
+ // Autofocus the chat input on page load
381
+ chatInput.focus();
382
+
383
+ fetch(`${API_URL}/health`)
384
+ .then(response => response.json())
385
+ .then(data => {
386
+ console.log('Engine status:', data);
387
+ })
388
+ .catch(error => {
389
+ console.error('Engine not available:', error);
390
+ chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
391
+ });
392
+ </script>
393
+ </body>
394
+ </html>
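The fetch() call above implies a server that exposes POST /chat/completions and streams SSE lines of the form data: {"token": ...}; a minimal Python client sketch under that assumption (the host/port and the serving process, referred to as engine.py in the error message above, are not defined in this file):

# Minimal streaming client sketch mirroring the fetch() call in the UI above.
# Assumptions: the serving process listens on http://localhost:8000 and emits
# SSE lines of the form 'data: {"token": "..."}' (neither is defined in this file).
import json
import requests

resp = requests.post(
    "http://localhost:8000/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
        "temperature": 0.8,
        "max_tokens": 512,
    },
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith("data: "):
        try:
            data = json.loads(line[len("data: "):])
        except json.JSONDecodeError:
            continue  # skip partial or non-JSON lines, as the UI does
        if "token" in data:
            print(data["token"], end="", flush=True)
print()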