from transformers import PreTrainedModel, PretrainedConfig
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
import torch
import torch.nn as nn

class TRMConfig(PretrainedConfig):
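    """Configuration for a tiny recursive GPT-2-style language model.

    `n_physical_layers` is the number of distinct transformer blocks that are
    stored; `n_loops` is how many times those blocks are applied in round-robin
    order, so the effective depth is `n_loops` with shared weights.
    """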
    model_type = "recursive_gpt"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=512,
        n_physical_layers=3,
        n_loops=8,
        n_head=8,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        scale_attn_weights=True,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_physical_layers = n_physical_layers
        self.n_loops = n_loops
        self.n_head = n_head
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.scale_attn_weights = scale_attn_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        # Required for transformers compatibility: GPT2Attention/GPT2MLP read
        # these names (max_position_embeddings sizes the causal-mask buffer).
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.num_hidden_layers = n_physical_layers
        self.max_position_embeddings = n_positions
        self.n_inner = None
        self.is_encoder_decoder = False

class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
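    """Decoder-only causal LM that reuses a small stack of GPT-2 blocks.

    Rather than stacking `n_loops` independent layers, the forward pass cycles
    through `n_physical_layers` pre-LayerNorm blocks, sharing their weights
    across depth.
    """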
    config_class = TRMConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # 1. Embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

        # 2. Physical blocks - the shared layers (structure matches the saved checkpoint)
        self.physical_blocks = nn.ModuleList([
            nn.ModuleDict({
                "ln_1": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
                "attn": GPT2Attention(config, layer_idx=i),
                "ln_2": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
                "mlp": GPT2MLP(4 * config.n_embd, config)
            }) for i in range(config.n_physical_layers)
        ])

        # 3. Final layer norm
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # 4. Language modeling head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Get embeddings
        token_embeds = self.wte(input_ids)
        pos_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
        pos_embeds = self.wpe(pos_ids)
        hidden_states = self.drop(token_embeds + pos_embeds)

        # GPT2Attention expects an additive 4D mask, so convert the usual 2D
        # padding mask (1 = keep, 0 = pad) here; causal masking is handled
        # inside GPT2Attention by its registered bias buffer.
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min

        # Apply recursive loops through physical blocks
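        # Blocks are reused in round-robin order: with n_physical_layers=3 and
        # n_loops=8 the visit sequence is 0, 1, 2, 0, 1, 2, 0, 1.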
        for loop in range(self.config.n_loops):
            block_idx = loop % self.config.n_physical_layers
            block = self.physical_blocks[block_idx]

            # Attention
            ln_output = block["ln_1"](hidden_states)
            attn_output = block["attn"](ln_output, attention_mask=attention_mask)[0]
            hidden_states = hidden_states + attn_output

            # MLP
            ln_output = block["ln_2"](hidden_states)
            mlp_output = block["mlp"](ln_output)
            hidden_states = hidden_states + mlp_output

        # Final layer norm and projection
        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=None,
            cross_attentions=None
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache is kept, so every generation step re-feeds the full sequence.
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _reorder_cache(self, past, beam_idx):
        # There is no KV cache, so there is nothing to reorder for beam search.
        return past
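

# ---------------------------------------------------------------------------
# Minimal smoke test (an illustrative sketch, not part of the checkpoint code).
# It builds a deliberately tiny config, runs one forward pass with labels, and
# greedy-decodes a few tokens through GenerationMixin. The sizes below and the
# GPT-2 eos/pad token id (50256) are arbitrary assumptions for a quick CPU
# check, not the training configuration. Because no KV cache is implemented,
# generation recomputes the full sequence at every step.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = TRMConfig(
        vocab_size=50257,
        n_positions=128,
        n_embd=128,
        n_physical_layers=2,
        n_loops=4,
        n_head=4,
        eos_token_id=50256,
        pad_token_id=50256,
    )
    model = TinyRecursiveModel(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    print("loss:", round(out.loss.item(), 4), "logits shape:", tuple(out.logits.shape))

    generated = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=8,
        do_sample=False,
    )
    print("generated shape:", tuple(generated.shape))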