KaiYinTAMU committed
Commit 44779f5 · verified · Parent: fa0aa6d

Upload 2 files

bidirectional_qwen3/__init__.py ADDED
@@ -0,0 +1 @@
+ from .bidirectional_qwen3 import *
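The star re-export flattens the package namespace, so downstream code can import the new classes from the top level (a quick sketch; assumes the repo root is on the import path):

from bidirectional_qwen3 import Qwen3BiModel, Qwen3BiForMNTP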
bidirectional_qwen3/bidirectional_qwen3.py ADDED
@@ -0,0 +1,131 @@
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import torch
+ from torch import nn
+ from transformers.cache_utils import Cache  # kept for potential future use
+ from transformers.models.qwen3.modeling_qwen3 import (
+     Qwen3Attention,
+     Qwen3DecoderLayer,
+     Qwen3MLP,
+     Qwen3RMSNorm,
+     Qwen3Model,
+     Qwen3ForCausalLM,
+     Qwen3PreTrainedModel,
+ )
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+ from transformers.modeling_utils import PreTrainedModel
+
+ try:
+     from peft import PeftModel
+ except ImportError:
+     # Soft dependency: fall back to None so the isinstance() check below
+     # degrades gracefully (isinstance against typing.Any raises TypeError).
+     PeftModel = None
+
+ logger = logging.get_logger(__name__)
+
+ # ---------------------------------------------------------------------------
+ # 1) Bidirectional attention: disable causal masking & sliding window
+ # ---------------------------------------------------------------------------
+ class ModifiedQwen3Attention(Qwen3Attention):
+     """Full-context self-attention (no causal mask)."""
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.is_causal = False
+         self.sliding_window = None
+
+
+ # ---------------------------------------------------------------------------
+ # 2) Decoder layer using the bidirectional attention module
+ # ---------------------------------------------------------------------------
+ class ModifiedQwen3DecoderLayer(Qwen3DecoderLayer):
+     """Decoder layer with full-context attention."""
+
+     def __init__(self, config: PretrainedConfig, layer_idx: int):
+         super().__init__(config, layer_idx)
+         self.self_attn = ModifiedQwen3Attention(config=config, layer_idx=layer_idx)
+         self.attention_type = "full_attention"
+         self.sliding_window = None
+
+
+ # ---------------------------------------------------------------------------
+ # 3) Backbone: Qwen-3 with bidirectional self-attention
+ # ---------------------------------------------------------------------------
+ class Qwen3BiModel(Qwen3Model):
+     """Qwen-3 backbone whose self-attention is bidirectional."""
+
+     _no_split_modules = ["ModifiedQwen3DecoderLayer"]
+
+     def __init__(self, config: PretrainedConfig):
+         super().__init__(config)
+         self.layers = nn.ModuleList(
+             [ModifiedQwen3DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
+         )
+         self.has_sliding_layers = False
+
+     @staticmethod
+     def _build_pad_bias(pad_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+         """[B,L] -> additive bias [B,1,1,L] with -inf on padding."""
+         neg_inf = torch.finfo(dtype).min
+         bias = (~pad_mask.bool()).to(dtype) * neg_inf
+         return bias[:, None, None, :]
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         # Default to keep-all if no mask is provided
+         if attention_mask is None:
+             if input_ids is None:
+                 raise ValueError("Either attention_mask or input_ids must be provided.")
+             attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+
+         pad_bias = self._build_pad_bias(attention_mask, self.embed_tokens.weight.dtype)
+         # Dict mask tells parent to skip causal-mask generation
+         attn_mask_dict = {"full_attention": pad_bias}
+
+         return super().forward(
+             input_ids=input_ids,
+             attention_mask=attn_mask_dict,
+             **kwargs,
+         )
+
+
+ # ---------------------------------------------------------------------------
+ # 4) Task head: MNTP (masked next-token), no generation API
+ # ---------------------------------------------------------------------------
+ class Qwen3BiForMNTP(Qwen3ForCausalLM):
+     """Bidirectional Qwen-3 with LM head for masked-token objectives."""
+
+     def __init__(self, config: PretrainedConfig):
+         # Bypass parent __init__ to wire a custom backbone
+         Qwen3PreTrainedModel.__init__(self, config)
+
+         self.model = Qwen3BiModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         self.post_init()
+
+     def generate(self, *args, **kwargs):  # type: ignore[override]
+         """Disabled: bidirectional backbone is not autoregressive."""
+         raise NotImplementedError(
+             "generate() is disabled: this backbone is bidirectional and not autoregressive."
+         )
+
+     # -------- PEFT helpers --------
+     def get_model_for_peft(self):
+         return self.model
+
+     def set_model_for_peft(self, model: PeftModel):  # type: ignore[override]
+         self.model = model
+
+     def save_peft_model(self, path: str):
+         # Guard the soft dependency: PeftModel is None when peft is not installed.
+         if PeftModel is not None and isinstance(self.model, PeftModel):
+             self.model.save_pretrained(path)
+         else:
+             raise ValueError("Backbone is not a PEFT model; nothing to save.")
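
A minimal usage sketch (not part of the commit; the checkpoint name is an illustrative assumption, and any Qwen3 causal-LM checkpoint with matching module names should load):

import torch
from transformers import AutoTokenizer

from bidirectional_qwen3 import Qwen3BiForMNTP

name = "Qwen/Qwen3-0.6B"  # illustrative checkpoint, not pinned by this commit
tokenizer = AutoTokenizer.from_pretrained(name)
model = Qwen3BiForMNTP.from_pretrained(name).eval()

batch = tokenizer(["Paris is the capital of France."], return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

# Every position attends to the full context; for MNTP training one would
# mask input tokens and compute cross-entropy against the shifted labels.
print(out.logits.shape)  # (batch, seq_len, vocab_size); generate() raises by design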