Commit d655135
Parent(s): 07a048e
Upload MixFormerSequentialForCausalLM
modeling_mixformer_sequential.py (CHANGED)
@@ -1,6 +1,36 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+# BSD 3-Clause License
+#
+# Copyright (c) 2022, Tri Dao, [email protected].
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 from __future__ import annotations
 
 import math

@@ -21,7 +51,8 @@ from .configuration_mixformer_sequential import MixFormerSequentialConfig
 @dataclass
 class InferenceParams:
     """Inference parameters that are passed to the main model in order
-    to efficienly calculate and store the context during inference.
+    to efficienly calculate and store the context during inference.
+    Adapted from https://github.com/Dao-AILab/flash-attention."""
     max_sequence_len: int
     max_batch_size: int
     sequence_len_offset: int = 0

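The hunk above only adds an attribution line, but the dataclass it touches is what carries decoding state between forward passes. As a point of reference (not part of this commit), here is a minimal sketch of how such an inference-parameters object is typically driven during generation; the field names mirror the ones visible in this file, while TinyInferenceParams and the loop around it are illustrative:

from dataclasses import dataclass, field


@dataclass
class TinyInferenceParams:
    # Fields mirror those visible in the diff: a fixed cache budget plus a
    # running offset into the sequence dimension.
    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)


# Illustrative decoding loop: one prefill pass over the whole prompt, then the
# offset advances by one token per generated step so cached keys/values are
# reused instead of recomputed.
params = TinyInferenceParams(max_sequence_len=2048, max_batch_size=1)
prompt_len = 16
# ... run the model on the full prompt with sequence_len_offset == 0 ...
params.sequence_len_offset = prompt_len
for _ in range(4):
    # ... run the model on the single newest token ...
    params.sequence_len_offset += 1
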
@@ -50,7 +81,8 @@ class Embedding(nn.Module):
         return hidden_states
 
 class RotaryEmbedding(nn.Module):
-    """PyTorch implementation of `flash-attn` RotaryEmbedding layer.
+    """PyTorch implementation of `flash-attn` RotaryEmbedding layer.
+    Adapted from https://github.com/Dao-AILab/flash-attention."""
 
     def __init__(
         self,

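The class above is documented as a PyTorch port of the flash-attn rotary embedding; the commit only adds the attribution line, and the actual implementation is not shown in this hunk. As background, a minimal sketch of the rotary idea in the non-interleaved style, rotating channel pairs by position-dependent angles (rotary_sketch and its default base are illustrative choices, not code from this file):

import torch


def rotary_sketch(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (batch, seqlen, nheads, head_dim), head_dim assumed even.
    seqlen, head_dim = x.shape[1], x.shape[-1]
    half = head_dim // 2
    # One inverse frequency per rotated channel pair.
    inv_freq = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) / half))
    angles = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)
    cos = angles.cos()[None, :, None, :]   # broadcast over batch and heads
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    # Standard 2D rotation applied pair-wise to the channels.
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


q = torch.randn(1, 8, 4, 32)
assert rotary_sketch(q).shape == q.shape
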
@@ -187,7 +219,7 @@ class RotaryEmbedding(nn.Module):
 
 def _update_kv_cache(kv, inference_params, layer_idx):
     """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
-    """
+    Adapted from https://github.com/Dao-AILab/flash-attention."""
     # Pre-allocate memory for key-values for inference.
     num_heads, head_dim = kv.shape[-2:]
     if layer_idx not in inference_params.key_value_memory_dict:

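The context lines above show the key idea: key/value tensors are pre-allocated per layer inside inference_params.key_value_memory_dict and filled in place. Below is a sketch of how such a function can be completed; only the lines visible in the hunk come from the file, and the buffer shape and slice arithmetic are a plausible completion that may differ from the actual implementation:

import torch


def update_kv_cache_sketch(kv, params, layer_idx):
    # kv: (batch_size, seqlen, 2, nheads, head_dim); params is any object with
    # the fields sketched earlier (max_batch_size, max_sequence_len,
    # sequence_len_offset, key_value_memory_dict).
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in params.key_value_memory_dict:
        # Pre-allocate one buffer per layer up to the configured maximums.
        params.key_value_memory_dict[layer_idx] = torch.empty(
            params.max_batch_size, params.max_sequence_len, 2, num_heads, head_dim,
            dtype=kv.dtype, device=kv.device,
        )
    cache = params.key_value_memory_dict[layer_idx]
    batch, seqlen = kv.shape[0], kv.shape[1]
    start, end = params.sequence_len_offset, params.sequence_len_offset + seqlen
    # Write the new keys/values in place and return everything cached so far.
    cache[:batch, start:end] = kv
    return cache[:batch, :end]

The returned slice grows with sequence_len_offset, which is why the dataclass in the earlier hunk tracks that offset explicitly.
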
@@ -281,6 +313,7 @@ class FusedMLP(nn.Module):
 
 class SelfAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
+    Adapted from https://github.com/Dao-AILab/flash-attention.
     Arguments
     ---------
         softmax_scale: The temperature to use for the softmax attention.

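For readers unfamiliar with the softmax_scale argument documented above: it is the factor applied to the attention scores before the softmax, acting as a temperature, and it conventionally defaults to 1/sqrt(head_dim). A self-contained sketch follows; the packed qkv layout is an assumption borrowed from flash-attention, since the hunk does not show this file's forward signature:

import math

import torch


def sdpa_sketch(qkv: torch.Tensor, causal: bool = True, softmax_scale=None):
    # qkv: (batch, seqlen, 3, nheads, head_dim), packed query/key/value.
    q, k, v = qkv.unbind(dim=2)
    softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1])
    # Scores are scaled before normalisation; softmax_scale acts as a temperature.
    scores = torch.einsum("bthd,bshd->bhts", q, k) * softmax_scale
    if causal:
        seqlen = q.shape[1]
        mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)
        scores = scores + mask
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhts,bshd->bthd", attn, v)

The CrossAttention hunk below is, in flash-attention at least, the analogous variant that takes an unpacked query together with a packed key/value tensor.
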
@@ -329,6 +362,7 @@ class SelfAttention(nn.Module):
 
 class CrossAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
+    Adapted from https://github.com/Dao-AILab/flash-attention.
     Arguments
     ---------
         softmax_scale: The temperature to use for the softmax attention.

@@ -412,7 +446,8 @@ def find_mha_dims(
 
 
 class MHA(nn.Module):
-    """Multi-head attention layer.
+    """Multi-head attention layer.
+    Adapted from https://github.com/Dao-AILab/flash-attention."""
 
     def __init__(
         self,

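To situate the MHA docstring change: a multi-head attention layer of this kind wraps an inner attention (like the SelfAttention and CrossAttention classes above) between an input projection and an output projection. A compact sketch using PyTorch's built-in scaled_dot_product_attention (available in PyTorch 2.0 and later); the names MHASketch, Wqkv, and out_proj are illustrative rather than read from this hunk:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MHASketch(nn.Module):
    # Project to packed q/k/v, attend per head, project back.
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seqlen, _ = x.shape
        qkv = self.Wqkv(x).view(batch, seqlen, 3, self.num_heads, self.head_dim)
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))  # (batch, nheads, seqlen, head_dim)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).reshape(batch, seqlen, -1)
        return self.out_proj(out)


x = torch.randn(2, 10, 64)
print(MHASketch(embed_dim=64, num_heads=4)(x).shape)  # torch.Size([2, 10, 64])

During incremental decoding, a cache update like the _update_kv_cache method in the next hunk would typically sit between the projection and the attention call, so that only the newest token's key/value pair has to be computed.
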
@@ -472,7 +507,8 @@ class MHA(nn.Module):
         self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
 
     def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None:
-        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
+        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
+        Adapted from https://github.com/Dao-AILab/flash-attention."""
 
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
 