|
|
import glob |
|
|
import math |
|
|
import os |
|
|
import shutil |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from huggingface_hub import snapshot_download |
|
|
from transformers import PreTrainedModel |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
|
|
from .configuration_nanogpt import NanoGPTConfig |
|
|
|
|
|
|
|
|
def _rms_norm(x: torch.Tensor) -> torch.Tensor: |
|
|
return F.rms_norm(x, (x.size(-1),)) |
|
|
|
|
|
|
|
|
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: |
|
|
assert x.ndim == 4 |
|
|
d = x.shape[3] // 2 |
|
|
x1, x2 = x[..., :d], x[..., d:] |
|
|
y1 = x1 * cos + x2 * sin |
|
|
y2 = x1 * (-sin) + x2 * cos |
|
|
out = torch.cat([y1, y2], 3) |
|
|
return out.to(x.dtype) |
|
|
|
|
|
|
|
|
def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
|
if n_rep == 1: |
|
|
return x |
|
|
bs, n_kv_heads, slen, head_dim = x.shape |
|
|
return ( |
|
|
x[:, :, None, :, :] |
|
|
.expand(bs, n_kv_heads, n_rep, slen, head_dim) |
|
|
.reshape(bs, n_kv_heads * n_rep, slen, head_dim) |
|
|
) |
|
|
|
|
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Causal self-attention with rotary embeddings, RMS-normed queries/keys and grouped KV heads.

    Queries use ``config.n_head`` heads; keys and values use ``config.n_kv_head`` heads,
    which are repeated ``n_head // n_kv_head`` times before attention.
    """

    def __init__(self, config: NanoGPTConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        # Each KV head must serve a whole number of query heads.
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    @staticmethod
    def _causal_mask(Tq: int, Tk: int, device) -> torch.Tensor:
        """Return a boolean (Tq, Tk) mask that is True where attention is allowed.

        The Tq queries are taken to be the last Tq of the Tk key positions. Query ``i``
        may therefore attend to keys ``0 .. Tk - Tq + i``: every earlier (cached) key
        plus the causal part of the new chunk.
        """
        return torch.ones((Tq, Tk), dtype=torch.bool, device=device).tril(diagonal=Tk - Tq)

    @staticmethod
    def _expand_attention_mask(attention_mask: torch.Tensor, Tk: int, device) -> torch.Tensor:
        """Convert *attention_mask* into a boolean key mask broadcastable against (B, H, Tq, Tk).

        A 2D mask is treated as (B, key_len); a 4D mask is used as given. Nonzero / True
        means "attend". The last dimension is aligned to the most recent ``Tk`` keys:
        - extra leading entries are dropped;
        - missing leading entries are treated as attendable, so a mask that covers only
          the new tokens does not hide the cached ones.

        Raises:
            ValueError: if the mask is neither 2D nor 4D.
        """
        mask = attention_mask.to(dtype=torch.bool, device=device)
        if mask.dim() == 2:
            mask = mask[:, None, None, :]
        elif mask.dim() != 4:
            raise ValueError("Unsupported attention_mask dimensions")
        key_len = mask.size(-1)
        if key_len > Tk:
            mask = mask[..., key_len - Tk:]
        elif key_len < Tk:
            missing = mask.new_ones((*mask.shape[:-1], Tk - key_len))
            mask = torch.cat([missing, mask], dim=-1)
        return mask

    def forward(self, x: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
        """Causal attention of *x* over itself.

        Args:
            x: hidden states, viewed below as (B, T, n_head, head_dim), so expected to be (B, T, n_embd).
            cos_sin: ``(cos, sin)`` rotary tables already sliced to the positions of *x*.
            kv_cache: accepted for interface compatibility; not read by this method.

        Returns:
            Tensor of shape (B, T, n_embd).
        """
        B, T, C = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        cos, sin = cos_sin
        q, k = _apply_rotary_emb(q, cos, sin), _apply_rotary_emb(k, cos, sin)
        q, k = _rms_norm(q), _rms_norm(k)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        Tq = q.size(2)
        Tk = k.size(2)
        nrep = self.n_head // self.n_kv_head
        k, v = _repeat_kv(k, nrep), _repeat_kv(v, nrep)
        if Tq == Tk:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        elif Tq == 1:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        else:
            y = F.scaled_dot_product_attention(
                q, k, v, attn_mask=self._causal_mask(Tq, Tk, q.device)
            )
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y

    def forward_with_cache(
        self,
        x: torch.Tensor,
        cos_sin,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Causal attention of the new tokens in *x* over cached keys plus themselves.

        Args:
            x: hidden states of the new tokens, expected to be (B, T, n_embd).
            cos_sin: ``(cos, sin)`` rotary tables sliced to the absolute positions of *x*.
            past_key_value: ``(past_k, past_v)`` as returned in ``present`` by an earlier
                call, each shaped (B, n_kv_head, T_past, head_dim); ``None`` or ``None``
                entries mean no cache.
            attention_mask: optional 2D (B, key_len) or 4D boolean/0-1 mask, True = attend.
                It is combined with the causal mask.
            use_cache: when True, return the concatenated (un-repeated) keys/values.

        Returns:
            ``(y, present)`` where *y* is (B, T, n_embd) and *present* is ``(k, v)`` or ``None``.

        NOTE(review): rotary positions come only from ``cos_sin``. Left-padded rows are
        therefore rotated as if the padding occupied real positions — confirm this is intended.
        """
        B, T, _ = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        cos, sin = cos_sin
        q, k = _apply_rotary_emb(q, cos, sin), _apply_rotary_emb(k, cos, sin)
        q, k = _rms_norm(q), _rms_norm(k)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if past_key_value is not None:
            past_k, past_v = past_key_value
            if past_k is not None and past_v is not None:
                k = torch.cat([past_k, k], dim=2)
                v = torch.cat([past_v, v], dim=2)

        # Cache the un-repeated KV heads; repetition happens per call below.
        present = (k, v) if use_cache else None

        Tq = q.size(2)
        Tk = k.size(2)
        nrep = self.n_head // self.n_kv_head
        k_rep = _repeat_kv(k, nrep)
        v_rep = _repeat_kv(v, nrep)

        if attention_mask is None:
            if Tq == Tk:
                y = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)
            elif Tq == 1:
                # A single new token may see every cached key and itself.
                y = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=False)
            else:
                # Several new tokens on top of a cache: is_causal aligns its triangle to the
                # top-left corner, so build the offset causal mask explicitly.
                y = F.scaled_dot_product_attention(
                    q, k_rep, v_rep, attn_mask=self._causal_mask(Tq, Tk, q.device)
                )
        else:
            # SDPA rejects attn_mask together with is_causal=True, so fold causality into the mask.
            allowed = self._expand_attention_mask(attention_mask, Tk, q.device) & self._causal_mask(
                Tq, Tk, q.device
            )
            # Use a large finite penalty rather than -inf / a boolean mask. Rows with no allowed
            # key (e.g. queries sitting on left padding) then stay finite instead of becoming NaN.
            bias = torch.zeros(allowed.shape, dtype=q.dtype, device=q.device).masked_fill(~allowed, -1e4)
            y = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=bias)

        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y, present
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Feed-forward sublayer: n_embd -> 4*n_embd -> n_embd with a squared-ReLU activation.

    Both projections are bias-free.
    """

    def __init__(self, config: NanoGPTConfig):
        super().__init__()
        width = config.n_embd
        hidden_width = 4 * width
        self.c_fc = nn.Linear(width, hidden_width, bias=False)
        self.c_proj = nn.Linear(hidden_width, width, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up, apply ``relu(.) ** 2``, then project back down."""
        activated = F.relu(self.c_fc(x)).square()
        return self.c_proj(activated)
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """One transformer layer: attention, then MLP, each applied to an RMS-normed input
    and added back through a residual connection."""

    def __init__(self, config: NanoGPTConfig, layer_idx: int):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
        """Return ``x`` updated by the attention and MLP residual branches."""
        hidden = x + self.attn(_rms_norm(x), cos_sin, kv_cache)
        return hidden + self.mlp(_rms_norm(hidden))
|
|
|
|
|
|
|
|
class NanoGPTModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the NanoGPT decoder.

    Modules: ``transformer.wte`` (token embedding, stored in bfloat16), ``transformer.h``
    (one :class:`Block` per layer) and ``lm_head``. Rotary cos/sin tables are kept as
    non-persistent buffers.

    The class also has helpers that locate loadable weights in a local directory or a Hub
    snapshot. When needed, they materialise ``pytorch_model.bin`` from a raw
    ``model_*.pt`` training checkpoint.
    """

    config_class = NanoGPTConfig

    # Weight files accepted as-is by _ensure_canonical_weights, checked in this order.
    _CANONICAL_WEIGHT_NAMES = (
        "pytorch_model.bin",
        "model.safetensors",
        "model.ckpt.index",
        "tf_model.h5",
        "flax_model.msgpack",
    )

    # Raw training checkpoints (e.g. "model_000500.pt"), used when no canonical file exists.
    _PT_PATTERN = "model_*.pt"

    @classmethod
    def _snapshot_kwargs(cls, source_kwargs: Dict) -> Dict:
        """Return the subset of loading kwargs that is forwarded to ``snapshot_download``."""
        keys = {
            "cache_dir",
            "force_download",
            "local_files_only",
            "proxies",
            "resume_download",
            "revision",
            "token",
            "use_auth_token",
        }
        return {k: source_kwargs[k] for k in keys if k in source_kwargs}

    @classmethod
    def _resolve_checkpoint_dir(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
        """Return a local directory that contains loadable weights.

        A path to an existing directory is used directly. Anything else is treated as a Hub
        repo id and downloaded with ``snapshot_download``. ``subfolder`` is then joined in
        both cases, and :meth:`_ensure_canonical_weights` is run on the result.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            base_dir = pretrained_model_name_or_path
        else:
            snapshot_params = cls._snapshot_kwargs(kwargs)
            # Pop both spellings so the legacy name is never forwarded alongside ``token``;
            # an explicit ``token`` wins when both are given.
            token = snapshot_params.pop("token", None)
            legacy_token = snapshot_params.pop("use_auth_token", None)
            if token is None:
                token = legacy_token
            if token is not None:
                snapshot_params["token"] = token
            base_dir = snapshot_download(pretrained_model_name_or_path, **snapshot_params)
        if subfolder:
            base_dir = os.path.join(base_dir, subfolder)
        cls._ensure_canonical_weights(base_dir)
        return base_dir

    @staticmethod
    def _pt_checkpoint_rank(path):
        """Sort key for raw checkpoints.

        Ranking rules:
        - ``model_<integer>.pt`` files rank by integer step, above any other name.
        - Names without an integer step compare lexicographically among themselves.
        """
        stem = os.path.splitext(os.path.basename(path))[0]
        step = stem[len("model_"):] if stem.startswith("model_") else ""
        if step.isdecimal():
            return (1, int(step), stem)
        return (0, 0, stem)

    @classmethod
    def _ensure_canonical_weights(cls, checkpoint_dir):
        """Return the path of a weight file in *checkpoint_dir*, creating one if needed.

        The first existing name in ``_CANONICAL_WEIGHT_NAMES`` is returned unchanged.
        Otherwise the newest ``model_*.pt`` checkpoint (highest integer step) is copied to
        ``pytorch_model.bin``. An existing ``pytorch_model.bin`` is always reused; delete it
        to pick up a newer ``.pt`` file.

        Raises:
            FileNotFoundError: if neither a canonical file nor a ``model_*.pt`` file exists.
        """
        for name in cls._CANONICAL_WEIGHT_NAMES:
            candidate = os.path.join(checkpoint_dir, name)
            if os.path.isfile(candidate):
                return candidate
        # Escape the directory so characters like "[" in its name are not treated as glob syntax.
        pt_candidates = glob.glob(os.path.join(glob.escape(checkpoint_dir), cls._PT_PATTERN))
        if not pt_candidates:
            raise FileNotFoundError(
                f"No checkpoint weights found in {checkpoint_dir}. Expected one of {cls._CANONICAL_WEIGHT_NAMES} "
                f"or files matching {cls._PT_PATTERN}."
            )
        # Plain string order would rank "model_999.pt" above "model_1000.pt".
        source_path = max(pt_candidates, key=cls._pt_checkpoint_rank)
        target_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
        # pytorch_model.bin is one of the canonical names checked above, so it does not exist here.
        shutil.copyfile(source_path, target_path)
        return target_path

    def __init__(self, config: NanoGPTConfig):
        super().__init__(config)
        config.use_cache = getattr(config, "use_cache", True)
        # Mirror the NanoGPT field names under generic attribute names (presumably read by HF utilities).
        config.num_hidden_layers = config.n_layer
        config.num_attention_heads = config.n_head
        config.hidden_size = config.n_embd
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Precompute rotary tables for 10x the configured sequence length.
        self.rotary_seq_len = config.sequence_len * 10
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        # The embedding table is stored in bfloat16; lookups are cast to float32 in the forward pass.
        self.transformer.wte.to(dtype=torch.bfloat16)
        self.post_init()

    def _init_weights(self, module: nn.Module):
        """Initialise one module.

        - Linear: normal with std ``1/sqrt(fan_in) * min(1, sqrt(fan_out/fan_in))``; bias zeroed.
        - Embedding: standard normal.
        """
        if isinstance(module, nn.Linear):
            fan_out = module.weight.size(0)
            fan_in = module.weight.size(1)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

    def _precompute_rotary_embeddings(self, seq_len: int, head_dim: int, base: int = 10000, device=None):
        """Return bfloat16 ``(cos, sin)`` tables.

        Each table has shape (1, seq_len, 1, ceil(head_dim / 2)). If no device is given,
        the embedding's device is used, with a meta device replaced by CPU.
        """
        if device is None:
            device = self.transformer.wte.weight.device
        if device.type == 'meta':
            device = torch.device('cpu')
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        # Leading/third singleton dims broadcast over batch and heads in _apply_rotary_emb.
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin

    def _apply_softcap(self, logits: torch.Tensor, softcap: float = 15) -> torch.Tensor:
        """Smoothly squash logits into (-softcap, softcap) via ``softcap * tanh(logits / softcap)``."""
        return softcap * torch.tanh(logits / softcap)

    def _forward_impl(self, idx: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
        """Run embedding, all blocks, final norm, LM head and soft-capping; return logits."""
        x = self.transformer.wte(idx)
        x = x.float()
        x = _rms_norm(x)
        for block in self.transformer.h:
            x = block(x, cos_sin, kv_cache)
        x = _rms_norm(x)
        logits = self.lm_head(x)
        return self._apply_softcap(logits)

    def forward(self, input_ids: torch.Tensor, labels=None, loss_reduction: str = 'mean', **kwargs):
        """Compute logits (and optionally cross-entropy loss) for ``input_ids`` of shape (B, T).

        ``labels`` are compared position-by-position with the logits (no internal shift);
        entries equal to -1 are ignored. Extra keyword arguments are accepted and ignored.

        Returns:
            ``{"loss": loss_or_None, "logits": logits}``.

        Raises:
            ValueError: if T exceeds the precomputed rotary table length.
        """
        idx = input_ids
        B, T = idx.size()
        T0 = 0
        max_len = self.cos.size(1)
        if T0 + T > max_len:
            raise ValueError(
                f"Sequence length {T0 + T} exceeds the precomputed rotary table ({max_len} positions)"
            )
        cos_sin = self.cos[:, T0:T0 + T], self.sin[:, T0:T0 + T]
        logits = self._forward_impl(idx, cos_sin, kv_cache=None)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.reshape(-1),  # reshape also accepts non-contiguous labels
                ignore_index=-1,
                reduction=loss_reduction,
            )
        return {"loss": loss, "logits": logits}
|
|
|
|
|
|
|
|
class NanoGPTChat(NanoGPTModel):
    """Chat-optimized variant with HF-friendly generate and support for KV cache."""

    def __init__(self, config: NanoGPTConfig):
        super().__init__(config)
        # Default used by forward() when the caller passes use_cache=None.
        self.use_cache = getattr(config, "use_cache", True)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        """Build the forward() kwargs for one generation step.

        When the cache already holds ``past_length`` positions, only the tokens after them
        are fed (at least the last token). When nothing is cached (``None``, an empty
        container, or empty layer entries), the full prompt is kept.
        """
        past_length = self._expand_past_length(past_key_values)
        if past_length > 0:
            if input_ids.size(1) > past_length:
                input_ids = input_ids[:, past_length:]
            else:
                input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}

    def _expand_past_length(self, past_key_values):
        """Return the number of cached positions, read from dim 2 of layer 0's key tensor.

        Returns 0 when there is no cache or layer 0 holds no keys.
        """
        if not past_key_values:
            return 0
        past_k, _ = past_key_values[0]
        if past_k is None:
            return 0
        return past_k.size(2)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        loss_reduction: str = "mean",
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the model on new tokens, optionally continuing from a per-layer KV cache.

        Args:
            input_ids: (B, T) token ids of the new tokens.
            attention_mask: optional mask passed to every layer's ``forward_with_cache``.
            past_key_values: per-layer ``(k, v)`` tuples from a previous call, or ``None``.
            use_cache: whether to return the updated cache; defaults to ``self.use_cache``.
            labels: optional targets compared position-by-position with the logits
                (no internal shift); -1 is ignored.
            loss_reduction: ``reduction`` passed to ``F.cross_entropy``.

        Returns:
            ``CausalLMOutputWithPast`` with soft-capped logits, optional loss, and the new
            cache (a list of per-layer ``(k, v)``) or ``None``.

        Raises:
            ValueError: if cached plus new positions exceed the rotary table length.
        """
        idx = input_ids
        B, T = idx.size()
        use_cache = self.use_cache if use_cache is None else use_cache
        past_length = self._expand_past_length(past_key_values)
        if past_length == 0:
            # Nothing cached (None, empty container, or empty entries): run as a fresh prefix.
            past_key_values = None
        T0 = past_length
        max_len = self.cos.size(1)
        if T0 + T > max_len:
            raise ValueError(
                f"Sequence length {T0 + T} exceeds the precomputed rotary table ({max_len} positions)"
            )
        # Rotary tables for the absolute positions of the new tokens.
        cos_sin = self.cos[:, T0:T0 + T], self.sin[:, T0:T0 + T]

        x = self.transformer.wte(idx)
        x = x.float()
        x = _rms_norm(x)

        presents = [] if use_cache else None
        for layer_idx, block in enumerate(self.transformer.h):
            past = past_key_values[layer_idx] if past_key_values is not None else None
            attn_output, present = block.attn.forward_with_cache(
                _rms_norm(x),
                cos_sin,
                past_key_value=past,
                attention_mask=attention_mask,
                use_cache=use_cache,
            )
            x = x + attn_output
            x = x + block.mlp(_rms_norm(x))
            if use_cache:
                presents.append(present)

        x = _rms_norm(x)
        # Apply the same soft-capping as NanoGPTModel._forward_impl so both classes agree on logits.
        logits = self._apply_softcap(self.lm_head(x))
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.reshape(-1),  # reshape also accepts non-contiguous labels
                ignore_index=-1,
                reduction=loss_reduction,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=presents,
        )
|
|
|
|
|
|
|
|
|
|
|
|