from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import copy
from .modeling_phi3_v import Phi3VForCausalLM, Phi3MLP
from .configuration_phi3_v import Phi3VConfig
from torch.optim import Adam
from typing import Optional, Tuple
from transformers import (
    PreTrainedModel,
    AutoConfig,
)


# Define the Gating Layer: a linear router that returns the top-k softmax
# scores and the indices of the selected experts for every token.
class GatingLayer(nn.Module):
    def __init__(self, input_dim, num_experts, k, layer_dtype=torch.float16):
        super(GatingLayer, self).__init__()
        self.num_experts = num_experts
        self.k = k
        self.gate = nn.Linear(input_dim, num_experts).to(dtype=layer_dtype)

    def forward(self, x):
        gate_scores = torch.softmax(self.gate(x), dim=-1)
        topk_values, topk_indices = torch.topk(gate_scores, self.k, dim=-1)
        return topk_values, topk_indices


# Sparse mixture-of-experts MLP: every token is routed to its top-k Phi3MLP
# experts and the gate-weighted expert outputs are summed.
class MoE(nn.Module):
    def __init__(self, input_dim, experts, gating_layer, config):
        super(MoE, self).__init__()
        self.experts = nn.ModuleList(experts)
        self.gating_layer = gating_layer
        self.output_dim = config.hidden_size

    def forward(self, x):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            gate_values, gate_indices = self.gating_layer(x)
            batch_size, seq_length, _ = x.size()
            moe_output = torch.zeros(
                batch_size,
                seq_length,
                self.output_dim,
                dtype=self.gating_layer.gate.weight.dtype,
                device=x.device,
            )
            # Experts are applied token by token in a Python loop: simple and
            # faithful to the routing, but not optimized for speed.
            for i in range(self.gating_layer.k):
                expert_outputs = []
                for b in range(batch_size):
                    for s in range(seq_length):
                        expert_index = gate_indices[b, s, i]
                        expert = self.experts[expert_index]
                        # Phi3MLP forward written out: fused gate/up projection,
                        # gated activation, then down projection.
                        up_states = expert.gate_up_proj(x[b, s].unsqueeze(0))
                        gate, up_states = up_states.chunk(2, dim=-1)
                        up_states = up_states * expert.activation_fn(gate)
                        expert_output = expert.down_proj(up_states)
                        expert_outputs.append(expert_output)
                expert_outputs = torch.stack(expert_outputs, dim=0).view(
                    batch_size, seq_length, -1
                )
                gate_values_i = (
                    gate_values[:, :, i].unsqueeze(-1).expand_as(expert_outputs)
                )
                moe_output += gate_values_i * expert_outputs
        return moe_output


# Define the ModifiedPhi3DecoderLayer: a Phi-3 decoder layer whose dense MLP is
# replaced by the MoE block; attention, layer norms, and dropouts are reused
# from the original layer.
class ModifiedPhi3DecoderLayer(nn.Module):
    def __init__(self, original_layer, moe_layer):
        super(ModifiedPhi3DecoderLayer, self).__init__()
        self.self_attn = original_layer.self_attn
        self.mlp = moe_layer
        self.input_layernorm = original_layer.input_layernorm
        self.resid_attn_dropout = original_layer.resid_attn_dropout
        self.resid_mlp_dropout = original_layer.resid_mlp_dropout
        self.post_attention_layernorm = original_layer.post_attention_layernorm

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        residual = hidden_states
        with torch.autocast(device_type="cuda", dtype=hidden_states.dtype):
            hidden_states = self.input_layernorm(hidden_states)

            # Self Attention
            attn_outputs = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            attn_output = attn_outputs[0]
            hidden_states = residual + self.resid_attn_dropout(attn_output)

            residual = hidden_states
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)
            hidden_states = residual + self.resid_mlp_dropout(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_outputs[1],)

        if use_cache:
            outputs += (attn_outputs[2],)
        return outputs


# Define the Phi3VForCausalLMMoEConfig: Phi3VConfig extended with the number
# of expert models and the top-k routing parameter.
class Phi3VForCausalLMMoEConfig(Phi3VConfig):
    model_type = "phi3_v_moe"

    def __init__(self, config=None, k=1, num_expert_models=2, **kwargs):
        if config is not None:
            kwargs.update(config.to_dict())
        super().__init__(**kwargs)
        self.k = k
        self.num_expert_models = num_expert_models
        self.architectures = ["Phi3VForCausalLMMoE"]
        self.auto_map = {
            "AutoConfig": "moe_phi3_v.Phi3VForCausalLMMoEConfig",
            "AutoModelForCausalLM": "moe_phi3_v.Phi3VForCausalLMMoE",
        }


# Define the MoE model: wraps a base Phi3VForCausalLM and replaces the MLP of
# every decoder layer with a mixture of expert MLPs, either copied from expert
# models or freshly initialized as placeholders.
class Phi3VForCausalLMMoE(Phi3VForCausalLM):
    config_class = Phi3VForCausalLMMoEConfig

    def __init__(
        self,
        config,
        base_model=None,
        expert_models=None,
        layer_dtype=torch.bfloat16,
        **kwargs,
    ):
        super().__init__(config)
        self.layer_dtype = layer_dtype
        self.custom_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        k = self.config.k
        self.num_layers = len(base_model.model.layers) if base_model else 0
        self.config.auto_map = {
            "AutoConfig": "moe_phi3_v.Phi3VForCausalLMMoEConfig",
            "AutoModelForCausalLM": "moe_phi3_v.Phi3VForCausalLMMoE",
        }
        self.model = base_model or Phi3VForCausalLM(self.config)

        if base_model and expert_models:
            # Real experts provided: copy their MLPs into the MoE layers.
            self.num_expert_models = len(expert_models)
            self._init_moe_layers(base_model, expert_models, k, layer_dtype)
        else:
            # No experts provided (e.g. when loading from a checkpoint): build
            # dummy experts so the architecture exists before weights are loaded.
            print(
                "Init function called and generating dummy experts: k=",
                k,
                "experts=",
                self.config.num_expert_models,
            )
            num_dummy_experts = self.config.num_expert_models
            self._init_moe_layers_with_dummy_experts(
                self.model, k, num_dummy_experts, layer_dtype
            )

        self.config.model_type = "phi3_v_moe"

    def _init_base_model(self):
        return PreTrainedModel(self.config)

    def _init_moe_layers(self, base_model, expert_models, k, layer_dtype):
        self.num_layers = len(base_model.model.layers)
        for i in tqdm(range(self.num_layers)):
            experts = []
            for expert_model in expert_models:
                expert = copy.deepcopy(expert_model.model.layers[i].mlp).to(
                    dtype=layer_dtype
                )
                experts.append(expert)
            gating_layer = GatingLayer(
                input_dim=self.config.hidden_size,
                num_experts=len(experts),
                k=k,
                layer_dtype=layer_dtype,
            )
            moe_layer = MoE(
                input_dim=self.config.hidden_size,
                experts=experts,
                gating_layer=gating_layer,
                config=self.config,
            ).to(dtype=layer_dtype)
            self.model.model.layers[i] = ModifiedPhi3DecoderLayer(
                self.model.model.layers[i], moe_layer
            ).to(dtype=layer_dtype)

    def _init_moe_layers_with_dummy_experts(
        self, base_model, k, num_dummy_experts, layer_dtype
    ):
        self.num_layers = len(base_model.model.layers)
        for i in tqdm(range(self.num_layers)):
            experts = []
            for _ in range(num_dummy_experts):
                dummy_expert = Phi3MLP(self.config).to(dtype=layer_dtype)
                experts.append(dummy_expert)
            gating_layer = GatingLayer(
                input_dim=self.config.hidden_size,
                num_experts=len(experts),
                k=k,
                layer_dtype=layer_dtype,
            )
            moe_layer = MoE(
                input_dim=self.config.hidden_size,
                experts=experts,
                gating_layer=gating_layer,
                config=self.config,
            ).to(dtype=layer_dtype)
            self.model.model.layers[i] = ModifiedPhi3DecoderLayer(
                self.model.model.layers[i], moe_layer
            ).to(dtype=layer_dtype)

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)

    def generate(self, *args, **kwargs):
        return self.model.generate(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Initialize the model using the superclass method
        model = super(Phi3VForCausalLMMoE, cls).from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )
        return model

    def preselect_gating_layer_params(
        self, processor, prompts_per_expert, epochs=1000
    ):
        # Warm-start the gating layers: embed a few prompts per expert, average
        # them, and fit a small linear gate so that each expert's mean prompt
        # embedding is classified as that expert.
        self.to(self.custom_device)
        self.eval()

        all_gating_layer_params = []

        for layer_idx in tqdm(range(self.num_layers)):
            print(f"Training gating layer parameters for layer {layer_idx}")
            expert_embeddings = []
            for prompts in prompts_per_expert:
                embeddings = []
                for prompt in prompts:
                    inputs = processor(
                        text=prompt["text"],
                        images=prompt["image"],
                        return_tensors="pt",
                    ).to(self.custom_device)
                    with torch.no_grad():
                        if (
                            inputs.pixel_values is not None
                            and inputs.image_sizes is not None
                        ):
                            # Multimodal prompt: use the vision-aware embeddings.
                            outputs = self.model.model.vision_embed_tokens(
                                inputs.input_ids,
                                pixel_values=inputs.pixel_values,
                                image_sizes=inputs.image_sizes,
                            ).mean(dim=1)
                        else:
                            # Text-only prompt: use the token embeddings.
                            outputs = self.model.model.embed_tokens(
                                inputs.input_ids
                            ).mean(dim=1)
                    embeddings.append(outputs)
                expert_embeddings.append(torch.stack(embeddings).mean(dim=0))

            expert_embeddings = torch.stack(expert_embeddings).to(self.layer_dtype)

            class SimpleGatingLayer(nn.Module):
                def __init__(self, input_dim, num_experts, layer_dtype=torch.float16):
                    super(SimpleGatingLayer, self).__init__()
                    self.gate = nn.Linear(input_dim, num_experts).to(dtype=layer_dtype)

                def forward(self, x):
                    return self.gate(x)

            input_dim = expert_embeddings.shape[2]
            num_experts = len(prompts_per_expert)

            gating_layer = SimpleGatingLayer(
                input_dim, num_experts, layer_dtype=self.layer_dtype
            ).to(self.custom_device)

            criterion = nn.CrossEntropyLoss()
            optimizer = Adam(gating_layer.parameters(), lr=1e-3)

            for epoch in tqdm(range(epochs), desc=f"Training Gating Layer {layer_idx}"):
                optimizer.zero_grad()
                expert_embeddings_reshaped = expert_embeddings.view(
                    num_experts, input_dim
                )
                outputs = gating_layer(expert_embeddings_reshaped)
                labels = torch.arange(num_experts).to(self.custom_device)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            all_gating_layer_params.append(gating_layer.state_dict())

        return all_gating_layer_params

    def set_gating_layer_params(self, gating_layer_params):
        # Load the pre-trained gate weights into each layer's MoE gating layer.
        for layer_idx, params in enumerate(gating_layer_params):
            self.model.model.layers[layer_idx].mlp.gating_layer.load_state_dict(params)


def freeze_except_gating_layers(model):
    # Usage: freeze_except_gating_layers(moe_model)
    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze gating layer parameters
    for layer in model.model.model.layers:
        for name, param in layer.mlp.gating_layer.named_parameters():
            param.requires_grad = True


def un_freeze_all(model):
    # Usage: un_freeze_all(moe_model)
    # Unfreeze all parameters
    for param in model.parameters():
        param.requires_grad = True


# Register the custom config and model so the Auto classes resolve the
# "phi3_v_moe" model type.
from transformers import AutoModelForCausalLM

AutoConfig.register("phi3_v_moe", Phi3VForCausalLMMoEConfig)
AutoModelForCausalLM.register(Phi3VForCausalLMMoEConfig, Phi3VForCausalLMMoE)
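
# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only, kept entirely in comments so that
# importing this module has no side effects). It assumes the base and expert
# checkpoints are Phi-3-vision models loadable with this package's
# Phi3VForCausalLM, that the matching processor is available via AutoProcessor,
# and that the checkpoint paths, image file, and chat-prompt format below are
# placeholders rather than part of this module.
#
#   import torch
#   from PIL import Image
#   from transformers import AutoProcessor
#
#   # Load a base model and two fine-tuned experts (placeholder paths).
#   base = Phi3VForCausalLM.from_pretrained("path/to/base", torch_dtype=torch.bfloat16)
#   expert_a = Phi3VForCausalLM.from_pretrained("path/to/expert_a", torch_dtype=torch.bfloat16)
#   expert_b = Phi3VForCausalLM.from_pretrained("path/to/expert_b", torch_dtype=torch.bfloat16)
#   processor = AutoProcessor.from_pretrained("path/to/base", trust_remote_code=True)
#
#   # Build the MoE model: every decoder-layer MLP becomes a 2-expert mixture.
#   config = Phi3VForCausalLMMoEConfig(config=base.config, k=1, num_expert_models=2)
#   moe_model = Phi3VForCausalLMMoE(
#       config,
#       base_model=base,
#       expert_models=[expert_a, expert_b],
#       layer_dtype=torch.bfloat16,
#   )
#
#   # Warm-start the gates from a few representative prompts per expert; each
#   # prompt is a dict with "text" and "image" keys.
#   image = Image.open("path/to/example.png")
#   prompts_per_expert = [
#       [{"text": "<|user|>\n<|image_1|>\nDescribe the microstructure.<|end|>\n<|assistant|>\n", "image": [image]}],
#       [{"text": "<|user|>\n<|image_1|>\nSummarize this figure.<|end|>\n<|assistant|>\n", "image": [image]}],
#   ]
#   gating_params = moe_model.preselect_gating_layer_params(
#       processor, prompts_per_expert, epochs=1000
#   )
#   moe_model.set_gating_layer_params(gating_params)
#
#   # Optionally train only the routers, leaving experts and attention frozen.
#   freeze_except_gating_layers(moe_model)
# ---------------------------------------------------------------------------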