Mistral flash attn packing (#646)
Browse files* add mistral monkeypatch
* add arg for decoder attention masl
* fix lint for duplicate code
* make sure to update transformers too
* tweak install for e2e
* move mistral patch to conditional
.github/workflows/tests.yml
CHANGED
|
@@ -44,7 +44,7 @@ jobs:
|
|
| 44 |
|
| 45 |
- name: Install dependencies
|
| 46 |
run: |
|
| 47 |
-
pip3 install -e .
|
| 48 |
pip3 install -r requirements-tests.txt
|
| 49 |
|
| 50 |
- name: Run tests
|
|
@@ -69,8 +69,7 @@ jobs:
|
|
| 69 |
|
| 70 |
- name: Install dependencies
|
| 71 |
run: |
|
| 72 |
-
pip3 install -e .
|
| 73 |
-
pip3 install flash-attn
|
| 74 |
pip3 install -r requirements-tests.txt
|
| 75 |
|
| 76 |
- name: Run e2e tests
|
|
|
|
| 44 |
|
| 45 |
- name: Install dependencies
|
| 46 |
run: |
|
| 47 |
+
pip3 install -U -e .
|
| 48 |
pip3 install -r requirements-tests.txt
|
| 49 |
|
| 50 |
- name: Run tests
|
|
|
|
| 69 |
|
| 70 |
- name: Install dependencies
|
| 71 |
run: |
|
| 72 |
+
pip3 install -U -e .[flash-attn]
|
|
|
|
| 73 |
pip3 install -r requirements-tests.txt
|
| 74 |
|
| 75 |
- name: Run e2e tests
|
requirements.txt
CHANGED
|
@@ -4,7 +4,7 @@ torch==2.0.1
|
|
| 4 |
auto-gptq
|
| 5 |
packaging
|
| 6 |
peft @ git+https://github.com/huggingface/peft.git
|
| 7 |
-
transformers @ git+https://github.com/huggingface/transformers.git@
|
| 8 |
bitsandbytes>=0.41.1
|
| 9 |
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
|
| 10 |
deepspeed
|
|
|
|
| 4 |
auto-gptq
|
| 5 |
packaging
|
| 6 |
peft @ git+https://github.com/huggingface/peft.git
|
| 7 |
+
transformers @ git+https://github.com/huggingface/transformers.git@78dd120
|
| 8 |
bitsandbytes>=0.41.1
|
| 9 |
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
|
| 10 |
deepspeed
|
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Flash attention monkey patch for mistral model"""
|
| 2 |
+
# pylint: disable=duplicate-code
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
import math
|
| 6 |
+
from typing import List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import transformers
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
from torch import nn
|
| 12 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 13 |
+
from transformers.models.mistral.modeling_mistral import (
|
| 14 |
+
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
| 15 |
+
)
|
| 16 |
+
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
| 17 |
+
|
| 18 |
+
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
| 22 |
+
flash_attn_varlen_qkvpacked_func,
|
| 23 |
+
)
|
| 24 |
+
except ImportError:
|
| 25 |
+
from flash_attn.flash_attn_interface import (
|
| 26 |
+
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def replace_mistral_attn_with_flash_attn(
|
| 34 |
+
packed: Optional[bool] = False,
|
| 35 |
+
):
|
| 36 |
+
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
| 37 |
+
_prepare_decoder_attention_mask
|
| 38 |
+
)
|
| 39 |
+
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
| 40 |
+
flashattn_forward
|
| 41 |
+
)
|
| 42 |
+
if packed:
|
| 43 |
+
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
| 44 |
+
MistralDecoderLayer
|
| 45 |
+
)
|
| 46 |
+
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
|
| 47 |
+
mistral_model_forward
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
| 52 |
+
# requires the attention mask to be the same as the key_padding_mask
|
| 53 |
+
def _prepare_decoder_attention_mask(
|
| 54 |
+
self,
|
| 55 |
+
attention_mask,
|
| 56 |
+
input_shape,
|
| 57 |
+
inputs_embeds,
|
| 58 |
+
past_key_values_length,
|
| 59 |
+
sliding_window,
|
| 60 |
+
): # pylint: disable=unused-argument
|
| 61 |
+
# [bsz, seq_len]
|
| 62 |
+
return attention_mask
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def flashattn_forward(
|
| 66 |
+
self,
|
| 67 |
+
hidden_states: torch.Tensor,
|
| 68 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 69 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 70 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 71 |
+
output_attentions: bool = False,
|
| 72 |
+
use_cache: bool = False,
|
| 73 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 74 |
+
max_seqlen: Optional[torch.Tensor] = None,
|
| 75 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 76 |
+
bsz, q_len, _ = hidden_states.size()
|
| 77 |
+
|
| 78 |
+
query_states = self.q_proj(hidden_states)
|
| 79 |
+
key_states = self.k_proj(hidden_states)
|
| 80 |
+
value_states = self.v_proj(hidden_states)
|
| 81 |
+
|
| 82 |
+
query_states = query_states.view(
|
| 83 |
+
bsz, q_len, self.num_heads, self.head_dim
|
| 84 |
+
).transpose(1, 2)
|
| 85 |
+
key_states = key_states.view(
|
| 86 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
| 87 |
+
).transpose(1, 2)
|
| 88 |
+
value_states = value_states.view(
|
| 89 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
| 90 |
+
).transpose(1, 2)
|
| 91 |
+
|
| 92 |
+
kv_seq_len = key_states.shape[-2]
|
| 93 |
+
if past_key_value is not None:
|
| 94 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
| 95 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 96 |
+
query_states, key_states = apply_rotary_pos_emb(
|
| 97 |
+
query_states, key_states, cos, sin, position_ids
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
if past_key_value is not None:
|
| 101 |
+
# reuse k, v, self_attention
|
| 102 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 103 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
| 104 |
+
|
| 105 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
| 106 |
+
|
| 107 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
| 108 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 109 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 110 |
+
|
| 111 |
+
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
| 112 |
+
# special handling using sample packing
|
| 113 |
+
qkv = torch.stack(
|
| 114 |
+
[query_states, key_states, value_states], dim=2
|
| 115 |
+
) # [bsz, nh, 3, q_len, hd]
|
| 116 |
+
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
| 117 |
+
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
| 118 |
+
|
| 119 |
+
output = flash_attn_varlen_qkvpacked_func(
|
| 120 |
+
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
| 121 |
+
)
|
| 122 |
+
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
| 123 |
+
attn_output = output
|
| 124 |
+
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
| 125 |
+
raise ValueError(
|
| 126 |
+
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
| 127 |
+
f" {attn_output.size()}"
|
| 128 |
+
)
|
| 129 |
+
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
| 130 |
+
attn_weights = None
|
| 131 |
+
else:
|
| 132 |
+
attn_weights = torch.matmul(
|
| 133 |
+
query_states, key_states.transpose(2, 3)
|
| 134 |
+
) / math.sqrt(self.head_dim)
|
| 135 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
| 136 |
+
raise ValueError(
|
| 137 |
+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
| 138 |
+
f" {attn_weights.size()}"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
if attention_mask is not None:
|
| 142 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
| 143 |
+
raise ValueError(
|
| 144 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
attn_weights = attn_weights + attention_mask
|
| 148 |
+
|
| 149 |
+
# upcast attention to fp32
|
| 150 |
+
attn_weights = nn.functional.softmax(
|
| 151 |
+
attn_weights, dim=-1, dtype=torch.float32
|
| 152 |
+
).to(query_states.dtype)
|
| 153 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 154 |
+
|
| 155 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 156 |
+
raise ValueError(
|
| 157 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 158 |
+
f" {attn_output.size()}"
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 162 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 163 |
+
|
| 164 |
+
attn_output = self.o_proj(attn_output)
|
| 165 |
+
|
| 166 |
+
if not output_attentions:
|
| 167 |
+
attn_weights = None
|
| 168 |
+
|
| 169 |
+
return attn_output, attn_weights, past_key_value
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def mistral_model_forward(
|
| 173 |
+
self,
|
| 174 |
+
input_ids: torch.LongTensor = None,
|
| 175 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 176 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 177 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 178 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 179 |
+
use_cache: Optional[bool] = None,
|
| 180 |
+
output_attentions: Optional[bool] = None,
|
| 181 |
+
output_hidden_states: Optional[bool] = None,
|
| 182 |
+
return_dict: Optional[bool] = None,
|
| 183 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 184 |
+
output_attentions = (
|
| 185 |
+
output_attentions
|
| 186 |
+
if output_attentions is not None
|
| 187 |
+
else self.config.output_attentions
|
| 188 |
+
)
|
| 189 |
+
output_hidden_states = (
|
| 190 |
+
output_hidden_states
|
| 191 |
+
if output_hidden_states is not None
|
| 192 |
+
else self.config.output_hidden_states
|
| 193 |
+
)
|
| 194 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 195 |
+
|
| 196 |
+
return_dict = (
|
| 197 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# retrieve input_ids and inputs_embeds
|
| 201 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 202 |
+
raise ValueError(
|
| 203 |
+
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
| 204 |
+
)
|
| 205 |
+
if input_ids is not None:
|
| 206 |
+
batch_size, seq_length = input_ids.shape
|
| 207 |
+
elif inputs_embeds is not None:
|
| 208 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
| 209 |
+
else:
|
| 210 |
+
raise ValueError(
|
| 211 |
+
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
seq_length_with_past = seq_length
|
| 215 |
+
past_key_values_length = 0
|
| 216 |
+
|
| 217 |
+
if past_key_values is not None:
|
| 218 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
| 219 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
| 220 |
+
|
| 221 |
+
cu_seqlens = None
|
| 222 |
+
max_seqlen = None
|
| 223 |
+
if position_ids is None:
|
| 224 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 225 |
+
position_ids = torch.arange(
|
| 226 |
+
past_key_values_length,
|
| 227 |
+
seq_length + past_key_values_length,
|
| 228 |
+
dtype=torch.long,
|
| 229 |
+
device=device,
|
| 230 |
+
)
|
| 231 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
| 232 |
+
else:
|
| 233 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
| 234 |
+
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
| 235 |
+
cu_seqlens = cu_seqlens.squeeze()
|
| 236 |
+
|
| 237 |
+
if inputs_embeds is None:
|
| 238 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 239 |
+
# embed positions
|
| 240 |
+
if attention_mask is None:
|
| 241 |
+
attention_mask = torch.ones(
|
| 242 |
+
(batch_size, seq_length_with_past),
|
| 243 |
+
dtype=torch.bool,
|
| 244 |
+
device=inputs_embeds.device,
|
| 245 |
+
)
|
| 246 |
+
attention_mask = (
|
| 247 |
+
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
| 248 |
+
attention_mask,
|
| 249 |
+
(batch_size, seq_length),
|
| 250 |
+
inputs_embeds,
|
| 251 |
+
past_key_values_length,
|
| 252 |
+
sliding_window=self.config.sliding_window,
|
| 253 |
+
)
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
hidden_states = inputs_embeds
|
| 257 |
+
|
| 258 |
+
if self.gradient_checkpointing and self.training:
|
| 259 |
+
if use_cache:
|
| 260 |
+
transformers.logger.warning_once(
|
| 261 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 262 |
+
)
|
| 263 |
+
use_cache = False
|
| 264 |
+
|
| 265 |
+
# decoder layers
|
| 266 |
+
all_hidden_states = () if output_hidden_states else None
|
| 267 |
+
all_self_attns = () if output_attentions else None
|
| 268 |
+
next_decoder_cache = () if use_cache else None
|
| 269 |
+
|
| 270 |
+
for idx, decoder_layer in enumerate(self.layers):
|
| 271 |
+
if output_hidden_states:
|
| 272 |
+
all_hidden_states += (hidden_states,)
|
| 273 |
+
|
| 274 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
| 275 |
+
|
| 276 |
+
if self.gradient_checkpointing and self.training:
|
| 277 |
+
|
| 278 |
+
def create_custom_forward(module):
|
| 279 |
+
def custom_forward(*inputs):
|
| 280 |
+
# None for past_key_value
|
| 281 |
+
return module(*inputs)
|
| 282 |
+
|
| 283 |
+
return custom_forward
|
| 284 |
+
|
| 285 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 286 |
+
create_custom_forward(decoder_layer),
|
| 287 |
+
hidden_states,
|
| 288 |
+
attention_mask,
|
| 289 |
+
position_ids,
|
| 290 |
+
past_key_value,
|
| 291 |
+
output_attentions,
|
| 292 |
+
None,
|
| 293 |
+
cu_seqlens,
|
| 294 |
+
max_seqlen,
|
| 295 |
+
)
|
| 296 |
+
else:
|
| 297 |
+
layer_outputs = decoder_layer(
|
| 298 |
+
hidden_states,
|
| 299 |
+
attention_mask=attention_mask,
|
| 300 |
+
position_ids=position_ids,
|
| 301 |
+
past_key_value=past_key_value,
|
| 302 |
+
output_attentions=output_attentions,
|
| 303 |
+
use_cache=use_cache,
|
| 304 |
+
cu_seqlens=cu_seqlens,
|
| 305 |
+
max_seqlen=max_seqlen,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
hidden_states = layer_outputs[0]
|
| 309 |
+
|
| 310 |
+
if use_cache:
|
| 311 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
| 312 |
+
|
| 313 |
+
if output_attentions:
|
| 314 |
+
all_self_attns += (layer_outputs[1],)
|
| 315 |
+
|
| 316 |
+
hidden_states = self.norm(hidden_states)
|
| 317 |
+
|
| 318 |
+
# add hidden states from the last decoder layer
|
| 319 |
+
if output_hidden_states:
|
| 320 |
+
all_hidden_states += (hidden_states,)
|
| 321 |
+
|
| 322 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 323 |
+
if not return_dict:
|
| 324 |
+
return tuple(
|
| 325 |
+
v
|
| 326 |
+
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
| 327 |
+
if v is not None
|
| 328 |
+
)
|
| 329 |
+
return BaseModelOutputWithPast(
|
| 330 |
+
last_hidden_state=hidden_states,
|
| 331 |
+
past_key_values=next_cache,
|
| 332 |
+
hidden_states=all_hidden_states,
|
| 333 |
+
attentions=all_self_attns,
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class MistralDecoderLayer(OriginalMistralDecoderLayer):
|
| 338 |
+
"""
|
| 339 |
+
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
|
| 340 |
+
"""
|
| 341 |
+
|
| 342 |
+
def forward(
|
| 343 |
+
self,
|
| 344 |
+
hidden_states: torch.Tensor,
|
| 345 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 346 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 347 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 348 |
+
output_attentions: Optional[bool] = False,
|
| 349 |
+
use_cache: Optional[bool] = False,
|
| 350 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 351 |
+
max_seqlen: Optional[torch.Tensor] = None,
|
| 352 |
+
) -> Tuple[
|
| 353 |
+
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
| 354 |
+
]:
|
| 355 |
+
"""
|
| 356 |
+
Args:
|
| 357 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 358 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
| 359 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
| 360 |
+
output_attentions (`bool`, *optional*):
|
| 361 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 362 |
+
returned tensors for more detail.
|
| 363 |
+
use_cache (`bool`, *optional*):
|
| 364 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 365 |
+
(see `past_key_values`).
|
| 366 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 367 |
+
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
|
| 368 |
+
"""
|
| 369 |
+
|
| 370 |
+
residual = hidden_states
|
| 371 |
+
|
| 372 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 373 |
+
|
| 374 |
+
# Self Attention
|
| 375 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 376 |
+
hidden_states=hidden_states,
|
| 377 |
+
attention_mask=attention_mask,
|
| 378 |
+
position_ids=position_ids,
|
| 379 |
+
past_key_value=past_key_value,
|
| 380 |
+
output_attentions=output_attentions,
|
| 381 |
+
use_cache=use_cache,
|
| 382 |
+
cu_seqlens=cu_seqlens,
|
| 383 |
+
max_seqlen=max_seqlen,
|
| 384 |
+
)
|
| 385 |
+
hidden_states = residual + hidden_states
|
| 386 |
+
|
| 387 |
+
# Fully Connected
|
| 388 |
+
residual = hidden_states
|
| 389 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 390 |
+
hidden_states = self.mlp(hidden_states)
|
| 391 |
+
hidden_states = residual + hidden_states
|
| 392 |
+
|
| 393 |
+
outputs = (hidden_states,)
|
| 394 |
+
|
| 395 |
+
if output_attentions:
|
| 396 |
+
outputs += (self_attn_weights,)
|
| 397 |
+
|
| 398 |
+
if use_cache:
|
| 399 |
+
outputs += (present_key_value,)
|
| 400 |
+
|
| 401 |
+
return outputs
|
src/axolotl/utils/models.py
CHANGED
|
@@ -150,6 +150,14 @@ def load_model(
|
|
| 150 |
# Note: This might overwrite previous additional_special_tokens
|
| 151 |
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
| 154 |
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
| 155 |
replace_llama_rope_with_xpos_rope,
|
|
|
|
| 150 |
# Note: This might overwrite previous additional_special_tokens
|
| 151 |
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
| 152 |
|
| 153 |
+
if cfg.is_mistral_derived_model and cfg.flash_attention:
|
| 154 |
+
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
| 155 |
+
replace_mistral_attn_with_flash_attn,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
LOG.info("patching with flash attention")
|
| 159 |
+
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
| 160 |
+
|
| 161 |
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
| 162 |
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
| 163 |
replace_llama_rope_with_xpos_rope,
|