dev7halo committed
Commit 684363d · verified · 1 parent: 65b7c6c

Upload 2 files

Files changed (2):
  1. install_vllm_support.sh +112 -0
  2. kormo_moe_vllm.py +621 -0
install_vllm_support.sh ADDED
@@ -0,0 +1,112 @@
+ #!/bin/bash
+ # KORMo MoE vLLM Support Installation Script
+ # This script automatically installs the necessary files for vLLM support
+
+ set -e
+
+ echo "=========================================="
+ echo "KORMo MoE vLLM Support Installer"
+ echo "=========================================="
+ echo ""
+
+ # 1. Check that vLLM is installed
+ echo "Step 1: Checking vLLM installation..."
+ if ! python3 -c "import vllm" 2>/dev/null; then
+     echo "❌ vLLM is not installed."
+     echo "Installing vLLM..."
+     pip install vllm
+     echo "✅ vLLM installed successfully"
+ else
+     echo "✅ vLLM is already installed"
+ fi
+ echo ""
+
+ # 2. Find the vLLM installation path
+ echo "Step 2: Finding vLLM installation path..."
+ VLLM_PATH=$(python3 -c "import vllm; import os; print(os.path.dirname(vllm.__file__))")
+ echo "vLLM path: $VLLM_PATH"
+ echo ""
+
+ # 3. Download kormo_moe_vllm.py
+ echo "Step 3: Downloading KORMo MoE vLLM implementation..."
+ if [ -f "kormo_moe_vllm.py" ]; then
+     echo "✅ kormo_moe_vllm.py found locally"
+ else
+     echo "Downloading from HuggingFace (dev7halo/KORMo-10B-sft-moe)..."
+     wget https://huggingface.co/dev7halo/KORMo-10B-sft-moe/resolve/main/kormo_moe_vllm.py -O kormo_moe_vllm.py 2>/dev/null || \
+     curl -L https://huggingface.co/dev7halo/KORMo-10B-sft-moe/resolve/main/kormo_moe_vllm.py -o kormo_moe_vllm.py 2>/dev/null || {
+         echo "❌ Failed to download. Please ensure kormo_moe_vllm.py is in the current directory."
+         echo "You can manually download it from:"
+         echo "https://huggingface.co/dev7halo/KORMo-10B-sft-moe/blob/main/kormo_moe_vllm.py"
+         exit 1
+     }
+     echo "✅ Downloaded successfully"
+ fi
+ echo ""
+
+ # 4. Copy the model file into vLLM
+ echo "Step 4: Installing KORMo MoE model file..."
+ TARGET_PATH="$VLLM_PATH/model_executor/models/kormo_moe.py"
+ cp kormo_moe_vllm.py "$TARGET_PATH"
+ echo "✅ Copied to $TARGET_PATH"
+ echo ""
+
+ # 5. Patch the model registry
+ echo "Step 5: Registering KORMo MoE in vLLM..."
+ REGISTRY_PATH="$VLLM_PATH/model_executor/models/registry.py"
+
+ # Check whether the architecture is already registered
+ if grep -q "KORMoMoeForCausalLM" "$REGISTRY_PATH"; then
+     echo "✅ KORMo MoE is already registered"
+ else
+     echo "Adding KORMo MoE to registry..."
+
+     # Create a backup
+     cp "$REGISTRY_PATH" "$REGISTRY_PATH.backup"
+     echo "Created backup: $REGISTRY_PATH.backup"
+
+     # Add KORMo MoE right after JambaForCausalLM
+     if grep -q "JambaForCausalLM" "$REGISTRY_PATH"; then
+         sed -i '/\"JambaForCausalLM\"/a\ \"KORMoMoeForCausalLM\": (\"kormo_moe\", \"KORMoMoeForCausalLM\"),' "$REGISTRY_PATH"
+         echo "✅ KORMo MoE registered successfully"
+     else
+         echo "⚠️ Could not find JambaForCausalLM in registry."
+         echo "Please manually add the following line to $REGISTRY_PATH in _TEXT_GENERATION_MODELS:"
+         echo ' "KORMoMoeForCausalLM": ("kormo_moe", "KORMoMoeForCausalLM"),'
+     fi
+ fi
+ echo ""
+
+ # 6. Verify the installation
+ echo "Step 6: Verifying installation..."
+ python3 << EOF
+ try:
+     from vllm.model_executor.models.registry import ModelRegistry
+     if "KORMoMoeForCausalLM" in ModelRegistry.get_supported_archs():
+         print("✅ Installation successful! KORMo MoE is now supported in vLLM")
+     else:
+         print("❌ Registration verification failed")
+         exit(1)
+ except Exception as e:
+     print(f"❌ Error during verification: {e}")
+     exit(1)
+ EOF
+
+ echo ""
+ echo "=========================================="
+ echo "Installation Complete!"
+ echo "=========================================="
+ echo ""
+ echo "You can now use KORMo MoE with vLLM:"
+ echo ""
+ echo "Example usage:"
+ echo ""
+ echo "from vllm import LLM, SamplingParams"
+ echo ""
+ echo "# Load the model"
+ echo "llm = LLM(model='dev7halo/KORMo-10B-sft-moe', dtype='float16')"
+ echo ""
+ echo "# Generate text"
+ echo "prompts = ['안녕하세요']"
+ echo "outputs = llm.generate(prompts, SamplingParams(temperature=0.8, max_tokens=100))"
+ echo ""
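
Note: as a quick end-to-end check after the installer finishes, a short Python smoke test along the following lines can be run. This is only a sketch; it reuses the registry and LLM APIs the script above already invokes, the file name smoke_test.py is hypothetical, and the prompt and sampling settings are simply the ones echoed in the installer's example output.

# smoke_test.py (hypothetical name) — minimal post-install check
from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry

# The architecture should now appear in vLLM's registry (same check as Step 6).
assert "KORMoMoeForCausalLM" in ModelRegistry.get_supported_archs(), \
    "KORMoMoeForCausalLM not registered; re-run install_vllm_support.sh"

# Load the model and generate a short completion, mirroring the echoed example.
llm = LLM(model="dev7halo/KORMo-10B-sft-moe", dtype="float16")
outputs = llm.generate(["안녕하세요"], SamplingParams(temperature=0.8, max_tokens=100))
print(outputs[0].outputs[0].text)
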
kormo_moe_vllm.py ADDED
@@ -0,0 +1,621 @@
+ """
+ vLLM-compatible implementation of KORMo MoE
+
+ This file should be placed in: /usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/kormo_moe.py
+
+ Usage:
+     from vllm import LLM
+
+     llm = LLM(
+         model="/path/to/kormo_moe_model",
+         trust_remote_code=False,  # Not needed with this implementation
+         dtype="float16",
+     )
+ """
+
+ from collections.abc import Iterable
+ from typing import Any, Optional, Union
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from vllm.attention import Attention
+ from vllm.compilation.decorators import support_torch_compile
+ from vllm.config import CacheConfig, VllmConfig
+ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.activation import SiluAndMul
+ from vllm.model_executor.layers.fused_moe import FusedMoE
+ from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.model_executor.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
+ from vllm.model_executor.layers.quantization import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
+ from vllm.sequence import IntermediateTensors
+
+ try:
+     from transformers import PretrainedConfig
+ except ImportError:
+     # Fallback for environments without transformers
+     PretrainedConfig = object
+
+ from .interfaces import SupportsLoRA, SupportsPP
+ from .utils import (
+     AutoWeightsLoader,
+     extract_layer_index,
+     is_pp_missing_parameter,
+     make_empty_intermediate_tensors_factory,
+     make_layers,
+     maybe_prefix,
+ )
+
+ logger = init_logger(__name__)
+
+
+ class KORMoMoeConfig(PretrainedConfig):
+     """Configuration class for KORMo MoE"""
+
+     model_type = "kormo_moe"
+
+     def __init__(
+         self,
+         vocab_size=112576,
+         hidden_size=6144,
+         intermediate_size=21504,
+         num_hidden_layers=48,
+         num_attention_heads=40,
+         num_key_value_heads=8,
+         hidden_act="silu",
+         max_position_embeddings=131072,
+         initializer_range=0.02,
+         rms_norm_eps=1e-05,
+         use_cache=True,
+         pad_token_id=None,
+         bos_token_id=0,
+         eos_token_id=1,
+         tie_word_embeddings=False,
+         rope_theta=500000.0,
+         attention_dropout=0.0,
+         rope_scaling=None,
+         head_dim=128,
+         # MoE specific
+         num_experts=2,
+         num_experts_per_tok=2,
+         moe_intermediate_size=None,
+         shared_expert_intermediate_size=None,
+         norm_topk_prob=True,
+         decoder_sparse_step=1,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads or num_attention_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.attention_dropout = attention_dropout
+         self.head_dim = head_dim or (self.hidden_size // self.num_attention_heads)
+
+         # MoE specific
+         self.num_experts = num_experts
+         self.num_experts_per_tok = num_experts_per_tok
+         self.moe_intermediate_size = (
+             moe_intermediate_size if moe_intermediate_size is not None else intermediate_size
+         )
+         self.shared_expert_intermediate_size = shared_expert_intermediate_size
+         self.norm_topk_prob = norm_topk_prob
+         self.decoder_sparse_step = decoder_sparse_step
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+
+ class KORMoMoEMLP(nn.Module):
+     """MLP for KORMo, used for the shared expert"""
+
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         quant_config: Optional[QuantizationConfig] = None,
+         reduce_results: bool = True,
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size,
+             [intermediate_size] * 2,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             reduce_results=reduce_results,
+         )
+         if hidden_act != "silu":
+             raise ValueError(f"Unsupported activation: {hidden_act}. Only silu is supported.")
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x):
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class KORMoSparseMoeBlock(nn.Module):
+     """KORMo Sparse MoE Block optimized for vLLM"""
+
+     def __init__(
+         self,
+         config: KORMoMoeConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.tp_size = get_tensor_model_parallel_world_size()
+
+         if self.tp_size > config.num_experts:
+             raise ValueError(
+                 f"Tensor parallel size {self.tp_size} is greater than "
+                 f"the number of experts {config.num_experts}."
+             )
+
+         # Use vLLM's FusedMoE for optimized expert routing
+         self.experts = FusedMoE(
+             num_experts=config.num_experts,
+             top_k=config.num_experts_per_tok,
+             hidden_size=config.hidden_size,
+             intermediate_size=config.moe_intermediate_size,
+             reduce_results=False,
+             renormalize=config.norm_topk_prob,
+             quant_config=quant_config,
+             prefix=f"{prefix}.experts",
+         )
+
+         # Router/gate
+         self.gate = ReplicatedLinear(
+             config.hidden_size,
+             config.num_experts,
+             bias=False,
+             quant_config=None,
+         )
+
+         # Shared expert (optional)
+         if config.shared_expert_intermediate_size and config.shared_expert_intermediate_size > 0:
+             self.shared_expert = KORMoMoEMLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=config.shared_expert_intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+                 reduce_results=self.experts.must_reduce_shared_expert_outputs(),
+             )
+             self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias=False)
+         else:
+             self.shared_expert = None
+             self.shared_expert_gate = None
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # NOTE: hidden_states can have either 1D or 2D shape.
+         orig_shape = hidden_states.shape
+         hidden_dim = hidden_states.shape[-1]
+         hidden_states = hidden_states.view(-1, hidden_dim)
+
+         # Shared expert pass
+         shared_output = None
+         if self.shared_expert is not None:
+             shared_output = self.shared_expert(hidden_states)
+             if self.shared_expert_gate is not None:
+                 shared_output = F.sigmoid(
+                     self.shared_expert_gate(hidden_states)
+                 ) * shared_output
+
+         # Router logits: (num_tokens, n_experts)
+         router_logits, _ = self.gate(hidden_states)
+
+         # Expert routing is performed inside FusedMoE
+         final_hidden_states = self.experts(
+             hidden_states=hidden_states,
+             router_logits=router_logits,
+         )
+
+         # Add the shared expert output
+         if shared_output is not None:
+             final_hidden_states = final_hidden_states + shared_output
+
+         # Tensor parallel reduction
+         if self.tp_size > 1:
+             final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
+                 final_hidden_states
+             )
+
+         return final_hidden_states.view(orig_shape)
+
+
+ class KORMoMoeAttention(nn.Module):
+     """KORMo MoE Attention mechanism"""
+
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         rope_theta: float = 500000,
+         rope_scaling: Optional[dict[str, Any]] = None,
+         max_position_embeddings: int = 131072,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+         self.head_dim = hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+         )
+
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+         )
+
+         self.attn = Attention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             cache_config=cache_config,
+             quant_config=quant_config,
+             prefix=f"{prefix}.attn",
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class KORMoMoeDecoderLayer(nn.Module):
+     """KORMo MoE Decoder Layer"""
+
+     def __init__(
+         self,
+         config: KORMoMoeConfig,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+
+         # Attention
+         self.self_attn = KORMoMoeAttention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             rope_theta=config.rope_theta,
+             rope_scaling=config.rope_scaling,
+             max_position_embeddings=config.max_position_embeddings,
+             cache_config=cache_config,
+             quant_config=quant_config,
+             prefix=f"{prefix}.self_attn",
+         )
+
+         # MoE MLP
+         self.mlp = KORMoSparseMoeBlock(
+             config=config,
+             quant_config=quant_config,
+             prefix=f"{prefix}.mlp",
+         )
+
+         # LayerNorms (using KORMo naming convention)
+         self.pre_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.pre_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         residual: Optional[torch.Tensor],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.pre_attention_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.pre_attention_layernorm(hidden_states, residual)
+
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+         )
+
+         # MoE MLP
+         hidden_states, residual = self.pre_mlp_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+
+         return hidden_states, residual
+
+
+ @support_torch_compile
+ class KORMoMoeModel(nn.Module):
+     """KORMo MoE Model"""
+
+     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+         super().__init__()
+
+         config = vllm_config.model_config.hf_config
+         cache_config = vllm_config.cache_config
+         quant_config = vllm_config.quant_config
+
+         self.vocab_size = config.vocab_size
+         self.config = config
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+
+         self.start_layer, self.end_layer, self.layers = make_layers(
+             config.num_hidden_layers,
+             lambda prefix: KORMoMoeDecoderLayer(
+                 config=config,
+                 cache_config=cache_config,
+                 quant_config=quant_config,
+                 prefix=prefix,
+             ),
+             prefix=f"{prefix}.layers",
+         )
+
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+             ["hidden_states", "residual"], config.hidden_size
+         )
+
+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.embed_tokens(input_ids)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         intermediate_tensors: Optional[IntermediateTensors] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, IntermediateTensors]:
+         if get_pp_group().is_first_rank:
+             if inputs_embeds is not None:
+                 hidden_states = inputs_embeds
+             else:
+                 hidden_states = self.get_input_embeddings(input_ids)
+             residual = None
+         else:
+             assert intermediate_tensors is not None
+             hidden_states = intermediate_tensors["hidden_states"]
+             residual = intermediate_tensors["residual"]
+
+         for layer in self.layers[self.start_layer : self.end_layer]:
+             hidden_states, residual = layer(positions, hidden_states, residual)
+
+         if not get_pp_group().is_last_rank:
+             return IntermediateTensors({
+                 "hidden_states": hidden_states,
+                 "residual": residual,
+             })
+
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+         """Return the expert parameter mapping for weight loading"""
+         return FusedMoE.make_expert_params_mapping(
+             ckpt_gate_proj_name="gate_proj",
+             ckpt_down_proj_name="down_proj",
+             ckpt_up_proj_name="up_proj",
+             num_experts=self.config.num_experts,
+         )
+
+     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         params_dict = dict(self.named_parameters())
+         loaded_params: set[str] = set()
+         expert_params_mapping = self.get_expert_mapping()
+
+         for name, loaded_weight in weights:
+             # Handle stacked parameters
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 if "mlp.experts" in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 if (name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict:
+                     continue
+                 if is_pp_missing_parameter(name, self):
+                     continue
+                 if name not in params_dict:
+                     continue
+
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Handle expert parameters
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+
+                     if is_pp_missing_parameter(name, self):
+                         continue
+                     if (name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict:
+                         continue
+
+                     param = params_dict[name]
+                     weight_loader = param.weight_loader
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
+                     )
+                     break
+                 else:
+                     # Handle regular parameters
+                     if (name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict:
+                         continue
+                     if is_pp_missing_parameter(name, self):
+                         continue
+
+                     # Fix gate weight naming: gate.linear.weight -> gate.weight
+                     if ".gate.linear.weight" in name:
+                         name = name.replace(".gate.linear.weight", ".gate.weight")
+
+                     if name not in params_dict:
+                         logger.warning(f"Parameter {name} not found in model")
+                         continue
+
+                     param = params_dict[name]
+                     weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                     weight_loader(param, loaded_weight)
+
+             loaded_params.add(name)
+
+         return loaded_params
+
+
+ class KORMoMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+     """KORMo MoE for Causal Language Modeling"""
+
+     fall_back_to_pt_during_load = False
+     packed_modules_mapping = {
+         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+         "gate_up_proj": ["gate_proj", "up_proj"],
+     }
+
+     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+         super().__init__()
+         config = vllm_config.model_config.hf_config
+         quant_config = vllm_config.quant_config
+
+         self.config = config
+         self.quant_config = quant_config
+
+         self.model = KORMoMoeModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
+         self.lm_head = ParallelLMHead(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+         )
+
+         if self.config.tie_word_embeddings:
+             self.lm_head.weight = self.model.embed_tokens.weight
+
+         self.logits_processor = LogitsProcessor(config.vocab_size)
+         self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors
+
+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.model.get_input_embeddings(input_ids)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         intermediate_tensors: Optional[IntermediateTensors] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, IntermediateTensors]:
+         hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds)
+         return hidden_states
+
+     def compute_logits(
+         self,
+         hidden_states: torch.Tensor,
+         sampling_metadata: SamplingMetadata,
+     ) -> Optional[torch.Tensor]:
+         logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
+         return logits
+
+     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+         loader = AutoWeightsLoader(self)
+         return loader.load_weights(weights)
+
+     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+         return self.model.get_expert_mapping()
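
Once the file above is installed and registered, the model loads like any in-tree vLLM architecture, so trust_remote_code is not required (as the module docstring notes). The brief usage sketch below is illustrative only: the tensor_parallel_size value is an assumption for a multi-GPU setup, and KORMoSparseMoeBlock raises a ValueError if it exceeds config.num_experts (which defaults to 2 in KORMoMoeConfig).

from vllm import LLM, SamplingParams

# KORMoMoeForCausalLM is resolved through vLLM's own registry; no trust_remote_code needed.
llm = LLM(
    model="dev7halo/KORMo-10B-sft-moe",
    dtype="float16",
    tensor_parallel_size=2,  # illustrative; must not exceed config.num_experts
)

params = SamplingParams(temperature=0.8, max_tokens=100)
for out in llm.generate(["안녕하세요"], params):
    print(out.outputs[0].text)
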