File size: 14,665 Bytes
788c379 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 |
import glob
import math
import os
import re
import shutil
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_nanogpt import NanoGPTConfig
def _rms_norm(x: torch.Tensor) -> torch.Tensor:
return F.rms_norm(x, (x.size(-1),))
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
assert x.ndim == 4
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:]
y1 = x1 * cos + x2 * sin
y2 = x1 * (-sin) + x2 * cos
out = torch.cat([y1, y2], 3)
return out.to(x.dtype)
def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
if n_rep == 1:
return x
bs, n_kv_heads, slen, head_dim = x.shape
return (
x[:, :, None, :, :]
.expand(bs, n_kv_heads, n_rep, slen, head_dim)
.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
)
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings, QK RMS-norm and grouped KV heads.

    Queries use ``n_head`` heads. Keys and values use ``n_kv_head`` heads, which are
    repeated ``n_head // n_kv_head`` times before attention.
    """

    def __init__(self, config: NanoGPTConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def _project_qkv(self, x: torch.Tensor, cos_sin):
        """Project ``x`` (B, T, n_embd) into q, k, v, each laid out as (B, heads, T, head_dim).

        Rotary embedding and RMS-norm are applied to q and k only.
        """
        B, T, _ = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        cos, sin = cos_sin
        q, k = _apply_rotary_emb(q, cos, sin), _apply_rotary_emb(k, cos, sin)
        q, k = _rms_norm(q), _rms_norm(k)
        return q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

    @staticmethod
    def _causal_keep_mask(Tq: int, Tk: int, device) -> torch.Tensor:
        """Return a boolean (Tq, Tk) mask that is True where a query may attend a key.

        The Tq queries are the last Tq of the Tk positions, so query i sees keys
        0 .. Tk - Tq + i (every cached key, plus causal order within the new chunk).
        """
        return torch.ones((Tq, Tk), dtype=torch.bool, device=device).tril(diagonal=Tk - Tq)

    @staticmethod
    def _padding_keep_mask(attention_mask: torch.Tensor, Tk: int, device) -> torch.Tensor:
        """Convert a 2-D or 4-D attention mask into a boolean keep-mask whose last dim is Tk.

        Nonzero/True entries mean "attend". A 2-D (B, L) mask becomes (B, 1, 1, L).
        A 4-D mask is used as given; presumably it is (B, 1|heads, Tq, L) — TODO confirm
        with callers. If L < Tk, the missing leading columns (the oldest, i.e. cached,
        keys) are treated as visible. If L > Tk, only the last Tk columns are kept.

        Raises:
            ValueError: if the mask is neither 2-D nor 4-D.
        """
        keep = attention_mask.to(dtype=torch.bool, device=device)
        if keep.dim() == 2:
            keep = keep[:, None, None, :]
        elif keep.dim() != 4:
            raise ValueError("Unsupported attention_mask dimensions")
        width = keep.size(-1)
        if width < Tk:
            # Left-padding with False would hide the entire cache; the cached keys are valid context.
            visible_prefix = keep.new_ones((*keep.shape[:-1], Tk - width))
            keep = torch.cat([visible_prefix, keep], dim=-1)
        elif width > Tk:
            keep = keep[..., width - Tk:]
        return keep

    def forward(self, x: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
        """Plain causal attention over ``x`` (B, T, n_embd); returns (B, T, n_embd).

        ``kv_cache`` is accepted for interface compatibility but is not used here.
        """
        B, T, _ = x.size()
        q, k, v = self._project_qkv(x, cos_sin)
        Tq = q.size(2)
        Tk = k.size(2)
        nrep = self.n_head // self.n_kv_head
        k, v = _repeat_kv(k, nrep), _repeat_kv(v, nrep)
        if Tq == Tk:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        elif Tq == 1:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        else:
            attn_mask = self._causal_keep_mask(Tq, Tk, q.device)
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        return self.c_proj(y)

    def forward_with_cache(
        self,
        x: torch.Tensor,
        cos_sin,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Causal attention over new tokens, optionally extending a key/value cache.

        Args:
            x: new hidden states, unpacked as (B, T, n_embd).
            cos_sin: rotary (cos, sin) tables for the T new positions.
            past_key_value: optional (k, v) from earlier steps, laid out (B, n_kv_head, T_past, head_dim).
                The new keys/values are concatenated after them along dim 2.
            attention_mask: optional 2-D or 4-D mask where nonzero means "attend"
                (see ``_padding_keep_mask``). It is combined with the causal mask.
            use_cache: when True, return the concatenated (k, v) as ``present``.

        Returns:
            (y, present): y is (B, T, n_embd); present is (k, v) or None.

        Raises:
            ValueError: if ``attention_mask`` is neither 2-D nor 4-D.
        """
        B, T, _ = x.size()
        q, k, v = self._project_qkv(x, cos_sin)
        if past_key_value is not None:
            past_k, past_v = past_key_value
            if past_k is not None and past_v is not None:
                k = torch.cat([past_k, k], dim=2)
                v = torch.cat([past_v, v], dim=2)
        present = (k, v) if use_cache else None
        Tq = q.size(2)
        Tk = k.size(2)
        nrep = self.n_head // self.n_kv_head
        k_rep = _repeat_kv(k, nrep)
        v_rep = _repeat_kv(v, nrep)
        if attention_mask is None:
            if Tq == Tk:
                y = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)
            elif Tq == 1:
                # A single new token may attend every cached key.
                y = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=False)
            else:
                # Several new tokens after a cache: full view of the prefix, causal within the chunk.
                causal = self._causal_keep_mask(Tq, Tk, q.device)
                y = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=causal)
        else:
            keep = self._padding_keep_mask(attention_mask, Tk, q.device)
            # SDPA raises if attn_mask and is_causal=True are both given, so causality goes into the mask.
            keep = keep & self._causal_keep_mask(Tq, Tk, q.device)
            # Finite additive penalty: rows where every key is masked still get finite scores.
            bias = (~keep).to(dtype=q.dtype) * -1e4
            y = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=bias, is_causal=False)
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y, present
class MLP(nn.Module):
    """Feed-forward sublayer: widen to 4 * n_embd, apply squared ReLU, project back."""

    def __init__(self, config: NanoGPTConfig):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, width, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden_act = F.relu(self.c_fc(x)).square()
        return self.c_proj(hidden_act)
class Block(nn.Module):
    """Pre-norm transformer layer: residual attention sublayer followed by a residual MLP sublayer."""

    def __init__(self, config: NanoGPTConfig, layer_idx: int):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
        # Each sublayer sees an RMS-normalized copy of the residual stream.
        attn_delta = self.attn(_rms_norm(x), cos_sin, kv_cache)
        x = x + attn_delta
        mlp_delta = self.mlp(_rms_norm(x))
        return x + mlp_delta
class NanoGPTModel(PreTrainedModel):
    """NanoGPT transformer exposed through the Hugging Face ``PreTrainedModel`` API.

    ``forward`` runs a full, uncached pass. The checkpoint helpers can make a
    directory that holds only raw ``model_*.pt`` files loadable, by copying the
    newest one to ``pytorch_model.bin``.
    """

    config_class = NanoGPTConfig
    # If any of these files exists in the checkpoint dir, it is returned as-is.
    _CANONICAL_WEIGHT_NAMES = (
        "pytorch_model.bin",
        "model.safetensors",
        "model.ckpt.index",
        "tf_model.h5",
        "flax_model.msgpack",
    )
    # Raw checkpoints; the newest match (natural order) is copied to pytorch_model.bin.
    _PT_PATTERN = "model_*.pt"

    @classmethod
    def _snapshot_kwargs(cls, source_kwargs: Dict) -> Dict:
        """Return the subset of ``source_kwargs`` that is forwarded to ``snapshot_download``."""
        keys = {
            "cache_dir",
            "force_download",
            "local_files_only",
            "proxies",
            "resume_download",
            "revision",
            "token",
            "use_auth_token",
        }
        return {k: source_kwargs[k] for k in keys if k in source_kwargs}

    @classmethod
    def _resolve_checkpoint_dir(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
        """Return a local checkpoint directory that contains loadable weights.

        An existing directory is used directly. Anything else is passed to
        ``snapshot_download`` as a repo id. ``subfolder`` is joined onto the result.

        Raises:
            FileNotFoundError: propagated from ``_ensure_canonical_weights``.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            base_dir = pretrained_model_name_or_path
        else:
            snapshot_params = cls._snapshot_kwargs(kwargs)
            # Prefer `token` and fall back to the legacy `use_auth_token`. Pop both so the
            # legacy name is never forwarded alongside `token`.
            token = snapshot_params.pop("token", None)
            legacy_token = snapshot_params.pop("use_auth_token", None)
            if token is None:
                token = legacy_token
            if token is not None:
                snapshot_params["token"] = token
            base_dir = snapshot_download(pretrained_model_name_or_path, **snapshot_params)
        if subfolder:
            base_dir = os.path.join(base_dir, subfolder)
        cls._ensure_canonical_weights(base_dir)
        return base_dir

    @staticmethod
    def _checkpoint_sort_key(path: str):
        """Return a natural-order key, so that ``model_1000.pt`` sorts after ``model_999.pt``.

        Digit runs compare as integers. The full path breaks ties, e.g. between
        differently zero-padded numbers.
        """
        # re.split with a capturing group puts the digit runs at odd indices, so
        # equal-index elements always have the same type across paths.
        chunks = re.split(r"(\d+)", path)
        natural = tuple(int(chunk) if i % 2 else chunk for i, chunk in enumerate(chunks))
        return natural, path

    @classmethod
    def _ensure_canonical_weights(cls, checkpoint_dir):
        """Return the path of a loadable weight file in ``checkpoint_dir``, creating one if needed.

        If a file from ``_CANONICAL_WEIGHT_NAMES`` exists, its path is returned.
        Otherwise the newest ``model_*.pt`` (natural order) is copied to
        ``pytorch_model.bin``, and that path is returned.

        Raises:
            FileNotFoundError: if neither a canonical file nor a ``model_*.pt`` file exists.
        """
        for name in cls._CANONICAL_WEIGHT_NAMES:
            candidate = os.path.join(checkpoint_dir, name)
            if os.path.isfile(candidate):
                return candidate
        pt_candidates = glob.glob(os.path.join(checkpoint_dir, cls._PT_PATTERN))
        if not pt_candidates:
            raise FileNotFoundError(
                f"No checkpoint weights found in {checkpoint_dir}. Expected one of {cls._CANONICAL_WEIGHT_NAMES} "
                f"or files matching {cls._PT_PATTERN}."
            )
        source_path = max(pt_candidates, key=cls._checkpoint_sort_key)
        target_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
        # pytorch_model.bin was absent at the loop above, so this normally copies. The
        # mtime comparison only matters if the file appeared in the meantime.
        if (
            not os.path.isfile(target_path)
            or os.path.getmtime(source_path) > os.path.getmtime(target_path)
        ):
            shutil.copyfile(source_path, target_path)
        return target_path

    def __init__(self, config: NanoGPTConfig):
        super().__init__(config)
        # Expose the standard HF attribute names as aliases of the NanoGPT config fields.
        config.use_cache = getattr(config, "use_cache", True)
        config.num_hidden_layers = config.n_layer
        config.num_attention_heads = config.n_head
        config.hidden_size = config.n_embd
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Rotary tables cover 10x config.sequence_len positions.
        self.rotary_seq_len = config.sequence_len * 10
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        # Non-persistent: recomputed at construction, not stored in the state dict.
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        # Store the token-embedding weights in bfloat16; the forward passes upcast the
        # embedding output to fp32 (see `.float()` in _forward_impl).
        self.transformer.wte.to(dtype=torch.bfloat16)
        # following HF API expectations
        self.post_init()

    def _init_weights(self, module: nn.Module):
        """Initialize Linear and Embedding weights.

        Linear layers use normal(0, std) with std = 1/sqrt(fan_in) * min(1, sqrt(fan_out/fan_in)),
        and their biases (if any) are zeroed. Embedding weights use a standard normal.
        """
        if isinstance(module, nn.Linear):
            fan_out = module.weight.size(0)
            fan_in = module.weight.size(1)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

    def _precompute_rotary_embeddings(self, seq_len: int, head_dim: int, base: int = 10000, device=None):
        """Return bfloat16 (cos, sin) rotary tables of shape (1, seq_len, 1, head_dim // 2) for even head_dim.

        The default device is the embedding weights' device, falling back to CPU when
        that device is ``meta``.
        """
        if device is None:
            device = self.transformer.wte.weight.device
        # Handle meta device case - use CPU as fallback
        if device.type == 'meta':
            device = torch.device('cpu')
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin

    def _apply_softcap(self, logits: torch.Tensor) -> torch.Tensor:
        """Squash logits smoothly into (-15, 15) via 15 * tanh(logits / 15)."""
        softcap = 15
        return softcap * torch.tanh(logits / softcap)

    def _forward_impl(self, idx: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
        """Embed ``idx``, upcast to fp32, run all blocks with RMS-norm at both ends, and return soft-capped logits."""
        x = self.transformer.wte(idx)
        x = x.float()
        x = _rms_norm(x)
        for block in self.transformer.h:
            x = block(x, cos_sin, kv_cache)
        x = _rms_norm(x)
        logits = self.lm_head(x)
        return self._apply_softcap(logits)

    def forward(self, input_ids: torch.Tensor, labels=None, loss_reduction: str = 'mean', **kwargs):
        """Run a full, uncached forward pass over ``input_ids`` of shape (B, T).

        ``labels`` are compared position-by-position with the logits; no shift is
        applied here, and entries equal to -1 are ignored.

        Returns:
            dict with "loss" (None when ``labels`` is None) and "logits".

        Raises:
            ValueError: if T exceeds the precomputed rotary table length.
        """
        idx = input_ids
        B, T = idx.size()
        max_positions = self.cos.size(1)
        if T > max_positions:
            raise ValueError(
                f"Sequence length {T} exceeds the rotary embedding cache ({max_positions} positions)."
            )
        T0 = 0
        cos_sin = self.cos[:, T0:T0 + T], self.sin[:, T0:T0 + T]
        logits = self._forward_impl(idx, cos_sin, kv_cache=None)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-1,
                reduction=loss_reduction,
            )
        return {"loss": loss, "logits": logits}
class NanoGPTChat(NanoGPTModel):
    """Chat-optimized variant with HF-friendly generate and support for KV cache.

    ``forward`` threads per-layer (k, v) caches through
    ``CausalSelfAttention.forward_with_cache`` and returns a ``CausalLMOutputWithPast``.
    """

    def __init__(self, config: NanoGPTConfig):
        super().__init__(config)
        # Default used when forward() is called with use_cache=None.
        self.use_cache = getattr(config, "use_cache", True)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        """Trim ``input_ids`` to the tokens not yet covered by ``past_key_values``.

        The cache length comes from ``_expand_past_length``, so an empty cache keeps the
        full prompt instead of dropping to its last token. If ``input_ids`` is not longer
        than the cache, only its last token is kept. Other keyword arguments pass through.
        """
        past_length = self._expand_past_length(past_key_values)
        if past_length > 0:
            if input_ids.size(1) > past_length:
                input_ids = input_ids[:, past_length:]
            else:
                input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}

    def _expand_past_length(self, past_key_values):
        """Return the number of cached positions (dim 2 of the first layer's key), or 0 if there is none."""
        if not past_key_values:
            return 0
        first_layer = past_key_values[0]
        if first_layer is None:
            return 0
        past_k = first_layer[0]
        if past_k is None:
            return 0
        return past_k.size(2)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        loss_reduction: str = "mean",
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the model on new tokens, continuing from ``past_key_values`` if given.

        Args:
            input_ids: (B, T) new token ids. Rotary positions start after the cached length.
            attention_mask: optional mask forwarded to every attention layer.
            past_key_values: optional per-layer (k, v) caches; entries may be None.
            use_cache: whether to return updated caches; None means ``self.use_cache``.
            labels: optional targets compared position-by-position with the logits;
                -1 entries are ignored.
            loss_reduction: ``reduction`` argument for ``F.cross_entropy``.

        Returns:
            CausalLMOutputWithPast with soft-capped logits, the optional loss, and the
            list of per-layer caches (None when caching is off).

        Raises:
            ValueError: if cached plus new positions exceed the rotary table length.
        """
        idx = input_ids
        B, T = idx.size()
        use_cache = self.use_cache if use_cache is None else use_cache
        past_length = self._expand_past_length(past_key_values)
        T0 = past_length
        max_positions = self.cos.size(1)
        if T0 + T > max_positions:
            raise ValueError(
                f"Sequence length {T0 + T} exceeds the rotary embedding cache ({max_positions} positions)."
            )
        cos_sin = self.cos[:, T0:T0 + T], self.sin[:, T0:T0 + T]
        x = self.transformer.wte(idx).float()
        x = _rms_norm(x)
        presents = [] if use_cache else None
        for layer_idx, block in enumerate(self.transformer.h):
            past = past_key_values[layer_idx] if past_key_values else None
            attn_output, present = block.attn.forward_with_cache(
                _rms_norm(x),
                cos_sin,
                past_key_value=past,
                attention_mask=attention_mask,
                use_cache=use_cache,
            )
            x = x + attn_output
            x = x + block.mlp(_rms_norm(x))
            if use_cache:
                presents.append(present)
        x = _rms_norm(x)
        # Same tanh soft-capping as NanoGPTModel._forward_impl, so cached and uncached logits agree.
        logits = self._apply_softcap(self.lm_head(x))
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-1,
                reduction=loss_reduction,
            )
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=presents,
        )
|