import torch
from torch.nn import functional as F

from .layers import layer_norm, linear, mlp
from .rope import apply_rotary_emb, precompute_freqs_cis
from .weights import AttentionWeights, TextModel, load_from_safetensors


def text_encoder(input_ids: torch.Tensor, w: TextModel):
    return F.embedding(input_ids, w.wte)
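
# Example (sketch): encoding a short prompt, assuming `w` is a TextModel loaded
# elsewhere (e.g. via load_from_safetensors):
#
#   input_ids = torch.tensor([[1, 2, 3]])  # (batch=1, seq_len=3)
#   embeds = text_encoder(input_ids, w)    # (1, 3, d_model)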


def attn_mask(pos, seq_len):
    """
    Create an attention mask that aligns with the bottom right of the
    attention matrix. For example, if q_len = 2 and kv_len = 5, we want the
    following:

    1 1 1 1 0
    1 1 1 1 1

    and not this, which is what we get by default if we just set is_causal.

    1 0 0 0 0
    1 1 0 0 0
    """
    mask = torch.ones(seq_len, pos + seq_len, dtype=torch.bool)
    mask[:, pos:] = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    mask = mask.unsqueeze(0).unsqueeze(0)  # Add batch and head dimensions.
    return mask
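
# Example: attn_mask(3, 2) (pos=3 cached tokens, 2 new queries) produces the mask
# from the docstring, with batch and head dims prepended, shape (1, 1, 2, 5):
#
#   1 1 1 1 0
#   1 1 1 1 1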


def attn(
    x: torch.Tensor,
    w: AttentionWeights,
    freqs_cis: torch.Tensor,
    layer_kv_cache: torch.Tensor,
):
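    # layer_kv_cache is None on the first (prefill) call. On later calls it holds
    # the keys and values from earlier positions, stacked as [k, v] with shape
    # (2, batch, n_heads, pos, head_dim), which is why the current position is
    # read from shape[3].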
    bsz, q_len, d_model = x.shape
    pos = 0 if layer_kv_cache is None else layer_kv_cache.shape[3]
    n_heads, head_dim = w.n_heads, d_model // w.n_heads

    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]
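
    # Rotary position embeddings are applied at absolute positions
    # pos .. pos + q_len - 1, so the new queries and keys line up with the
    # already-rotated keys stored in the cache.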
    position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
    q = apply_rotary_emb(q, freqs_cis, position_ids)
    k = apply_rotary_emb(k, freqs_cis, position_ids)
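
    # Keep the newly computed keys/values; only these are returned, leaving it to
    # the caller to append them to its running cache.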
    k_, v_ = k, v
    if layer_kv_cache is not None:
        k = torch.cat([layer_kv_cache[0], k], dim=2)
        v = torch.cat([layer_kv_cache[1], v], dim=2)
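
    # After concatenation, keys/values cover all pos + q_len positions, matching
    # the width of the bottom-right-aligned mask from attn_mask(pos, q_len).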
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask(pos, q_len)).to(
        # This type conversion isn't needed when running in PyTorch directly, but the
        # ONNX export runs attention in float32 because the attention mask is cast to
        # float32.
        x.dtype
    )
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out, torch.stack([k_, v_])


def text_decoder(
    inputs_embeds: torch.Tensor,
    w: TextModel,
    kv_cache: torch.Tensor,
    freqs_cis: torch.Tensor,
):
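    # Each block feeds the same layer-normed input to both attention and the MLP
    # and adds both outputs back to the residual stream (parallel branches rather
    # than sequential). The per-layer caches returned by attn are stacked into a
    # single (n_layers, 2, batch, n_heads, q_len, head_dim) tensor of new
    # keys/values.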
    hidden_BTC = inputs_embeds
    new_kv_cache = [torch.empty(0)] * len(w.blocks)
    for i, block in enumerate(w.blocks):
        l_in = layer_norm(hidden_BTC, block.ln)
        l_attn, new_kv_cache[i] = attn(l_in, block.attn, freqs_cis, kv_cache[i])
        l_mlp = mlp(l_in, block.mlp)
        hidden_BTC = hidden_BTC + l_attn + l_mlp
    return hidden_BTC, torch.stack(new_kv_cache)


def lm_head(hidden_BTC: torch.Tensor, w: TextModel):
    # Only the final position is projected to logits, which is all that's needed
    # for next-token prediction.
    hidden_BC = hidden_BTC[:, -1, :]
    hidden_BC = layer_norm(hidden_BC, w.post_ln)
    logits = linear(hidden_BC, w.lm_head)
    return logits
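

# Usage sketch: one possible greedy decode loop built from the functions above.
# Assumptions (not taken from this file): `w` and `freqs_cis` are prepared by the
# caller, freqs_cis covers every position that will be generated, and the KV cache
# is grown outside attn/text_decoder by concatenating the newly returned
# keys/values along the sequence dimension.
def _greedy_decode_sketch(
    prompt_ids: torch.Tensor,
    w: TextModel,
    freqs_cis: torch.Tensor,
    max_new_tokens: int = 32,
):
    # Prefill: no cache exists yet, so pass a per-layer list of Nones.
    hidden, kv_cache = text_decoder(
        text_encoder(prompt_ids, w), w, [None] * len(w.blocks), freqs_cis
    )
    out_ids = []
    for _ in range(max_new_tokens):
        next_id = lm_head(hidden, w).argmax(dim=-1, keepdim=True)  # (batch, 1)
        out_ids.append(next_id)
        # Decode one token, then append its new keys/values to the cache along the
        # sequence axis of (n_layers, 2, batch, n_heads, seq, head_dim).
        hidden, new_kv = text_decoder(text_encoder(next_id, w), w, kv_cache, freqs_cis)
        kv_cache = torch.cat([kv_cache, new_kv], dim=4)
    return torch.cat(out_ids, dim=1)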