from dataclasses import dataclass
from typing import Literal

import torch
from torch import nn
from torch.nn import functional as F


def gelu_approx(x):
    # Tanh-approximated GELU activation.
    return F.gelu(x, approximate="tanh")


@dataclass
class LinearWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
    return F.linear(x, w.weight, w.bias)


@dataclass
class LayerNormWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
    # Normalize over the last dimension; w.bias.shape is the normalized shape.
    return F.layer_norm(x, w.bias.shape, w.weight, w.bias)


@dataclass
class MLPWeights:
    fc1: LinearWeights
    fc2: LinearWeights
    act: Literal["gelu_approx"] = "gelu_approx"


def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
    # Standard transformer feed-forward block: fc1 -> activation -> fc2.
    x = linear(x, w.fc1)
    if w.act == "gelu_approx":
        x = gelu_approx(x)
    else:
        raise NotImplementedError(f"Activation function {w.act} not implemented.")
    x = linear(x, w.fc2)
    return x


@dataclass
class AttentionWeights:
    qkv: LinearWeights  # fused query/key/value projection: d_model -> 3 * d_model
    proj: LinearWeights  # output projection: d_model -> d_model
    n_heads: int


def attn(x: torch.Tensor, w: AttentionWeights) -> torch.Tensor:
    # Multi-head self-attention over (bsz, q_len, d_model) inputs.
    bsz, q_len, d_model = x.shape
    n_heads, head_dim = w.n_heads, d_model // w.n_heads

    # Project once to fused QKV, split into q/k/v, and reshape each to
    # (bsz, n_heads, q_len, head_dim) for scaled_dot_product_attention.
    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]
    out = F.scaled_dot_product_attention(q, k, v)

    # Merge heads back and apply the output projection.
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out
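

# --- Minimal usage sketch (not part of the original file) --------------------
# Shows how the weight containers and functional blocks above compose. The
# sizes (d_model = 64, n_heads = 4, hidden = 256) and the random initialization
# are illustrative assumptions, not values taken from the source.
if __name__ == "__main__":
    d_model, n_heads, hidden = 64, 4, 256

    attn_w = AttentionWeights(
        qkv=LinearWeights(torch.randn(3 * d_model, d_model), torch.zeros(3 * d_model)),
        proj=LinearWeights(torch.randn(d_model, d_model), torch.zeros(d_model)),
        n_heads=n_heads,
    )
    mlp_w = MLPWeights(
        fc1=LinearWeights(torch.randn(hidden, d_model), torch.zeros(hidden)),
        fc2=LinearWeights(torch.randn(d_model, hidden), torch.zeros(d_model)),
    )

    x = torch.randn(2, 16, d_model)  # (bsz, q_len, d_model)
    y = mlp(attn(x, attn_w), mlp_w)  # both blocks preserve the input shape
    print(y.shape)  # torch.Size([2, 16, 64])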