from transformers import PreTrainedModel, PretrainedConfig
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
import torch
import torch.nn as nn


class TRMConfig(PretrainedConfig):
    """Configuration for a tiny recursive GPT: a few physical GPT-2 blocks applied for n_loops iterations."""

    model_type = "recursive_gpt"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=512,
        n_physical_layers=3,
        n_loops=8,
        n_head=8,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        scale_attn_weights=True,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_physical_layers = n_physical_layers
        self.n_loops = n_loops
        self.n_head = n_head
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.scale_attn_weights = scale_attn_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn
        # Aliases required for transformers compatibility (GPT2Attention and
        # GPT2MLP read these names instead of the n_* variants).
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.num_hidden_layers = n_physical_layers
        self.max_position_embeddings = n_positions  # read by GPT2Attention when building its causal-mask buffer
        self.n_inner = None
        self.is_encoder_decoder = False


class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
    """Causal LM that cycles through a small stack of GPT-2 blocks for config.n_loops iterations."""

    config_class = TRMConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # 1. Embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

        # 2. Physical blocks (matching the saved checkpoint structure); the same
        #    n_physical_layers blocks are reused across the recursive loops.
        self.physical_blocks = nn.ModuleList([
            nn.ModuleDict({
                "ln_1": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
                "attn": GPT2Attention(config, layer_idx=i),
                "ln_2": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
                "mlp": GPT2MLP(4 * config.n_embd, config),
            }) for i in range(config.n_physical_layers)
        ])

        # 3. Final layer norm
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # 4. Language modeling head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights
        self.post_init()

    # Embedding accessors: these let PreTrainedModel.tie_weights() tie
    # lm_head.weight to wte.weight (as declared in _tied_weights_keys) and
    # make resize_token_embeddings() work.
    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("You must pass input_ids.")
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Token + learned position embeddings
        token_embeds = self.wte(input_ids)
        pos_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
        pos_embeds = self.wpe(pos_ids)
        hidden_states = self.drop(token_embeds + pos_embeds)

        # GPT2Attention expects an additive (0 / large-negative) broadcastable mask,
        # not the 2D 0/1 padding mask produced by tokenizers, so convert it here.
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min

        # Recursive loops: cycle through the physical blocks for n_loops iterations,
        # reusing the same weights (block_idx = loop % n_physical_layers).
        for loop in range(self.config.n_loops):
            block_idx = loop % self.config.n_physical_layers
            block = self.physical_blocks[block_idx]

            # Pre-norm attention with residual connection
            ln_output = block["ln_1"](hidden_states)
            attn_output = block["attn"](ln_output, attention_mask=attention_mask)[0]
            hidden_states = hidden_states + attn_output

            # Pre-norm MLP with residual connection
            ln_output = block["ln_2"](hidden_states)
            mlp_output = block["mlp"](ln_output)
            hidden_states = hidden_states + mlp_output

        # Final layer norm and projection
        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=None,
            cross_attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache is kept, so the full sequence is re-encoded at every generation step.
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _reorder_cache(self, past, beam_idx):
        # Nothing to reorder for beam search since no cache is used.
        return past
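

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: it assumes no
    # pretrained checkpoint and simply exercises the classes defined above with
    # random weights and dummy token ids to check shapes, loss, and generate().
    config = TRMConfig(vocab_size=50257, n_embd=128, n_head=4, n_physical_layers=3, n_loops=6)
    model = TinyRecursiveModel(config)
    model.eval()

    dummy_ids = torch.randint(0, config.vocab_size, (2, 16))
    mask = torch.ones_like(dummy_ids)

    with torch.no_grad():
        out = model(input_ids=dummy_ids, attention_mask=mask, labels=dummy_ids)
    print("logits shape:", tuple(out.logits.shape))    # expected: (2, 16, 50257)
    print("loss:", float(out.loss))

    # Greedy decoding via GenerationMixin; without a KV cache every step re-runs
    # the full recursive stack, which is acceptable for a model this small.
    generated = model.generate(dummy_ids, attention_mask=mask, max_new_tokens=8, do_sample=False)
    print("generated shape:", tuple(generated.shape))  # expected: (2, 24)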