# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""ff_layer.py

This module contains the implementation of the feedforward layers.

Supported ff_layer_type:
    'mlp': Multi-Layer Perceptron
    'gmlp': Gated Multi-Layer Perceptron, simplified version of Mixtral Expert with num_experts=1 and top_k=1.
            This is not the spatial gating MLP (https://arxiv.org/abs/2105.08050).
    'moe': Mixtral of Experts, modified from the original source code:
           https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/mixtral/modeling_mixtral.py

Usage:
    from model.ff_layer import get_ff_layer

    config = PerceiverTFConfig()  # or any type of PretrainedConfig()
    config.ff_layer_type = 'moe'  # or 'mlp'
    config.moe_num_experts = 4
    config.moe_topk = 2
    config.hidden_act = 'gelu'  # or any type of activation function, e.g., 'silu'

    ff_layer = get_ff_layer(config, input_size, widening_factor)

What ff_layer returns:
    - It returns (hidden_states, router_logits) for MoE and (hidden_states, None) for MLP.
    - router_logits has the shape of (batch_size * sequence_length, n_experts) for MoE.

"""
from typing import Any, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.configuration_utils import PretrainedConfig
from transformers.activations import ACT2FN

from model.ops import get_layer_norm
from model.ops import optional_compiler_disable, optional_compiler_dynamic


class MixtralBlockSparseTop2MLP(nn.Module):
    """
    The Gated Multilayer Perceptron (GMLP) used in Mixtral of Experts (MoE).
    """

    def __init__(self, config: PretrainedConfig, input_size: int, widening_factor: int):
        super().__init__()
        self.hidden_dim = input_size
        self.ffn_dim = int(input_size * widening_factor)

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.gate = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.gate(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states
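
# Note: this expert follows the Mixtral gated-FFN form w2(act_fn(w1(x)) * w3(x)); the projection
# named `w3` in the upstream Mixtral code is called `gate` here. An illustrative shape check
# (the sizes below are arbitrary assumptions, not values used by YourMT3):
#
#     expert = MixtralBlockSparseTop2MLP(config, input_size=32, widening_factor=4)  # config.hidden_act must be set
#     y = expert(torch.rand(2, 8, 32))  # y.shape == torch.Size([2, 8, 32])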


class MixtralSparseMoeBlock(nn.Module):
    """
    This implementation is strictly equivalent to standard MoE with full capacity (no
    dropped tokens). It's faster since it formulates MoE operations in terms of
    block-sparse operations to accommodate imbalanced assignments of tokens to experts,
    whereas standard MoE either (1) drops tokens at the cost of reduced performance or
    (2) sets the capacity factor to the number of experts and thus wastes computation
    and memory on padding.
    """

    def __init__(self, config, input_size: int, widening_factor: int):
        super().__init__()
        self.hidden_dim = input_size
        self.widening_factor = widening_factor
        self.num_experts = config.moe_num_experts
        self.top_k = config.moe_topk

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = nn.ModuleList(
            [MixtralBlockSparseTop2MLP(config, self.hidden_dim, self.widening_factor) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route each token to its top_k experts and combine the weighted expert outputs."""
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
                                          dtype=hidden_states.dtype,
                                          device=hidden_states.device)

        # One-hot encode the selected experts to create an expert mask.
        # This will be used to easily index which tokens each expert is solicited for.
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                continue

            # in torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

            # However `index_add_` only supports torch tensors for indexing, so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
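
# Routing sketch (illustrative numbers, assuming num_experts=4 and top_k=2): for one token, router
# logits of [2.0, 1.0, 0.5, 0.1] give softmax weights of roughly [0.57, 0.21, 0.13, 0.09]. The top-2
# experts (0 and 1) are kept and their weights renormalized to ~0.73 and ~0.27, so the token's output
# is ~0.73 * expert_0(x) + ~0.27 * expert_1(x); the remaining experts are skipped for that token.
# router_logits is returned un-normalized with shape (batch_size * sequence_length, num_experts).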


class MLP(nn.Module):
    """A Standard Transformer-style dense module to follow attention."""

    def __init__(self, config: PretrainedConfig, input_size: int, widening_factor: int):
        super().__init__()
        self.dense1 = nn.Linear(input_size, widening_factor * input_size)
        self.dense2 = nn.Linear(widening_factor * input_size, input_size)

        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        hidden_states = self.dense1(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.dense2(hidden_states)
        return hidden_states, None
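
# Like the MoE block, MLP returns a (hidden_states, router_logits) pair with router_logits=None,
# so callers can unpack the feedforward output the same way regardless of ff_layer_type.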


class SimpleGMLP(nn.Module):
    """A simple Gated Multilayer Perceptron (aka 'gmlp'), without the spatial gating mechanism.

    Note that this is not the spatial gating MLP (https://arxiv.org/abs/2105.08050).
    - A simplified MLP with a gating mechanism, adapted from the Mixtral expert with the
      number of experts and top_k both set to 1.
    - Dropout layers are added after the gating and the output projection.
    - This type of gated MLP was also used in T5 v1.1.
    """

    def __init__(self, config: PretrainedConfig, input_size: int, widening_factor: int):
        super().__init__()
        self.hidden_dim = input_size
        self.ffn_dim = int(input_size * widening_factor)
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.gate = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
        self.dropout1 = nn.Dropout(config.dropout_rate)
        self.dropout2 = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.gate(hidden_states)
        current_hidden_states = self.dropout1(current_hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        current_hidden_states = self.dropout2(
            current_hidden_states)  # Residual connection is applied outside of this module.
        return current_hidden_states, None
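
# Note: unlike 'mlp', the 'gmlp' path also reads `config.dropout_rate`, which get_ff_layer() below
# does not assert; a config lacking it fails with an AttributeError during construction. A minimal
# config sketch (the values here are arbitrary assumptions):
#
#     config.ff_layer_type = 'gmlp'
#     config.hidden_act = 'silu'
#     config.dropout_rate = 0.1
#     ff_layer = get_ff_layer(config, input_size=32, widening_factor=1)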


def get_ff_layer(config: PretrainedConfig, input_size: int, widening_factor: int):
    if config.ff_layer_type == 'moe':
        assert hasattr(config, 'moe_num_experts') and hasattr(config, 'moe_topk') and hasattr(config, 'hidden_act')
        return MixtralSparseMoeBlock(config, input_size, widening_factor)
    elif config.ff_layer_type == 'mlp':
        assert hasattr(config, 'hidden_act')
        return MLP(config, input_size, widening_factor)
    elif config.ff_layer_type == 'gmlp':
        assert hasattr(config, 'hidden_act')
        return SimpleGMLP(config, input_size, widening_factor)
    else:
        raise ValueError(
            f"Unsupported ff_layer_type: {config.ff_layer_type}. Supported types are 'moe', 'mlp' and 'gmlp'.")


def test_get_ff_layer():
    from model.ff_layer import get_ff_layer
    from model.perceiver_helper import PerceiverTFConfig

    input_size = 32
    widening_factor = 1

    # Test for MoE
    config = PerceiverTFConfig()  # or any type of PretrainedConfig()
    config.ff_layer_type = 'moe'
    config.moe_num_experts = 4
    config.moe_topk = 2
    config.hidden_act = 'silu'
    ff_layer = get_ff_layer(config, input_size, widening_factor)
    x = torch.rand(2, 8, input_size)
    hidden_states, router_logits = ff_layer(x)
    print(hidden_states.shape, router_logits.shape)  # (2, 8, 32), (2*8, 4)

    # Test for MLP
    config.ff_layer_type = 'mlp'
    config.hidden_act = 'gelu'
    ff_layer = get_ff_layer(config, input_size, widening_factor)
    hidden_states, _ = ff_layer(x)
    print(hidden_states.shape)  # (2, 8, 32)

    # Test for (simple)gMLP
    config.ff_layer_type = 'gmlp'
    config.hidden_act = 'silu'
    ff_layer = get_ff_layer(config, input_size, widening_factor)
    hidden_states, _ = ff_layer(x)
    print(hidden_states.shape)  # (2, 8, 32)
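
# Allow running this module directly as a quick smoke test. This guard is a convenience
# addition and is not required by the library code above.
if __name__ == "__main__":
    test_get_ff_layer()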