Update modeling_minicpm.py
Modify modeling_minicpm.py to use LSE compression.

modeling_minicpm.py (CHANGED, +134 -72)

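A note on the idea behind the change before the diff itself: this commit threads a second, coarser compressed-key stream (`compress_k2`) through the cache and the attention call. `CompressK` itself is not part of this diff, so the following is only a sketch under the assumption that it mean-pools keys over strided windows (the same mean-over-window rule that `get_compress_k` applies to pending tokens during decode); the helper name `mean_pool_keys` is made up for the illustration.

```python
import torch

def mean_pool_keys(k: torch.Tensor, kernel_size: int, kernel_stride: int) -> torch.Tensor:
    """Summarize keys by averaging each strided window (hypothetical stand-in for CompressK)."""
    # k: [seq_len, num_kv_heads, head_dim] -> [num_windows, num_kv_heads, head_dim]
    if k.shape[0] < kernel_size:
        return k.new_zeros((0,) + tuple(k.shape[1:]))
    windows = k.unfold(0, kernel_size, kernel_stride)   # [num_windows, heads, dim, kernel_size]
    return windows.mean(dim=-1)                          # average each window into one key

k = torch.randn(1024, 2, 128)                                  # [seq_len, num_kv_heads, head_dim]
k_c  = mean_pool_keys(k, kernel_size=32,  kernel_stride=16)    # fine summary (compress_k)
k_c2 = mean_pool_keys(k, kernel_size=128, kernel_stride=64)    # 4x coarser summary (compress_k2)
print(k_c.shape, k_c2.shape)   # torch.Size([63, 2, 128]) torch.Size([15, 2, 128])
```
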
@@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from torch import
+from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin, DynamicLayer
@@ -47,7 +47,9 @@ from transformers.utils import (
 )
 from transformers.utils.import_utils import is_torch_fx_available
 
-
+
+
+from .configuration_minicpm import MiniCPMConfig  #! must change this
 
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -68,50 +70,28 @@ from functools import lru_cache
 def compressed_attention(
     q: torch.Tensor,
     k: torch.Tensor,
-
+    k2: torch.Tensor,
     kernel_size: int,
     kernel_stride: int,
     block_size: int,
     topk: int,
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
+    cu_seqlens_k2: torch.Tensor,
     max_seqlen_q: int,
     max_seqlen_k: int,
     sm_scale: float = None,
     init_blocks: int = 1,
     local_blocks: int = 2,
-    cache_lens
+    cache_lens=None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.
-
-    Args:
-        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
-        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
-        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
-        kernel_size (int): kernel size in compress_key_value
-        kernel_stride (int): stride of compress_key_value
-        block_size (int): key value block size for topk sparse attention.
-        topk (int): number of blocks for each query.
-        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
-        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
-        max_seqlen_q (int): max q len of the batch.
-        max_seqlen_k (int): max k len of the batch.
-        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
-        init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
-        local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
-        cache_lens (torch.Tensor, optional): shape [batch_size], used to record the cache length of each query. Defaults to None.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
-    """
     with torch.no_grad():
         batch_size = cu_seqlens_q.shape[0] - 1

         # Check if it's prefilling stage
         is_prefilling = cache_lens is None or (cache_lens == 0).all().item()
-
-        # prefilling stage
-        if is_prefilling:
+
+        if is_prefilling:  # prefilling stage
             # Calculate q_idx for each query position in each batch
             cache_lens = torch.zeros(batch_size, dtype=torch.int32, device=q.device)
             q_idx = torch.cat([
@@ -119,25 +99,24 @@ def compressed_attention(
                 max_seqlen_q - (cu_seqlens_q[i + 1] - cu_seqlens_q[i])) // block_size
                 for i in range(batch_size)
             ], dim=0)  # shape: [total_q_len]
-        # decoding stage
-
-
-            q_idx = cache_lens // block_size
+        else:  # decoding stage
+            # Each batch has only one query (last position)
+            q_idx = cache_lens // block_size  # shape: [batch_size] = [total_q_len] in decoding

-        #
+        # compute attention score
         score = infllmv2_attn_stage1(
             q.contiguous(),
             k.contiguous(),
-
+            k2.contiguous(),
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k=cu_seqlens_k,
+            cu_seqlens_v=cu_seqlens_k2,
             max_seqlen_q=max_seqlen_q,
             max_seqlen_k=max_seqlen_k,
-            causal=is_prefilling
-
-        score = score[:, :q_idx.shape[0], :]
-
-        # Shape: [num_heads, total_q_len, num_blocks]
+            causal=is_prefilling
+        )
+        score = score[:, :q_idx.shape[0], :]  # [num_heads, total_q_len, num_blocks]
+
         block_score = max_pooling_1d_varlen(
             score.contiguous(),
             cu_seqlens_q,
@@ -148,7 +127,9 @@
             local_blocks=local_blocks,
             init_blocks=init_blocks,
             block_size=block_size,
-            stride=kernel_stride
+            stride=kernel_stride
+        )  # shape: [num_heads, total_q_len, num_blocks]
+

         # get topk
         topk = min(topk, block_score.shape[-1])
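A small, self-contained illustration of the varlen bookkeeping used in `compressed_attention` above: packed sequences are addressed through cumulative lengths (`cu_seqlens`), each query position maps to a block index, and during decode the single new query per sequence takes its block index from its cache length (the `q_idx = cache_lens // block_size` branch). The numbers are made up, and the prefill formula is simplified here (the real one also offsets by `cache_lens` and `max_seqlen_q`).

```python
import torch

# Three packed sequences of lengths 5, 3, 7 laid out along dim 0.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        torch.cumsum(seq_lens, dim=0, dtype=torch.int32)])
print(cu_seqlens.tolist())     # [0, 5, 8, 15]; sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1]

block_size = 4
# Prefill: every token of every sequence gets the index of the block it falls in.
q_idx_prefill = torch.cat([torch.arange(int(n)) // block_size for n in seq_lens])
print(q_idx_prefill.tolist())  # [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]

# Decode: one new query per sequence; its block index follows from the cached length.
cache_lens = torch.tensor([130, 514, 65], dtype=torch.int32)
q_idx_decode = cache_lens // block_size
print(q_idx_decode.tolist())   # [32, 128, 16]
```
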
@@ -262,6 +243,11 @@ class InfLLMv2CacheLayer(DynamicLayer):
         self.no_compress_k_cache = []
         self.cached_compressed_cu_seqlens = torch.tensor([], dtype=torch.int32)
         self.compress_k_cache_varlen = torch.tensor([], dtype=torch.float32)
+        # Add support for compress_k2
+        self.compress_k2_cache = []
+        self.cached_compressed_cu_seqlens2 = torch.tensor([], dtype=torch.int32)
+        self.compress_k2_cache_varlen = torch.tensor([], dtype=torch.float32)
+        self.no_compress_k2_cache = []

     def update_no_rope_key(self, key_states):
         if self.no_rope_keys.numel() == 0:
@@ -303,12 +289,45 @@ class InfLLMv2CacheLayer(DynamicLayer):
                 k_chunk_list.append(None)
         return k_chunk_list

+    def update_compress_k2(self, key_states, cu_seqlens=None):
+        if len(self.compress_k2_cache) == 0:
+            if cu_seqlens is not None:
+                self.cached_compressed_cu_seqlens2 = cu_seqlens.clone()
+            self.compress_k2_cache_varlen = key_states
+            split_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            self.compress_k2_cache = list(torch.split(key_states, split_sizes))
+        else:
+            for index, k in enumerate(key_states):
+                if k is not None:
+                    self.compress_k2_cache[index] = torch.cat([self.compress_k2_cache[index], k], dim=0)
+            new_seq_lens = torch.tensor([tensor.shape[0] for tensor in self.compress_k2_cache], dtype=torch.int32)
+            new_cumsum = torch.cumsum(new_seq_lens, dim=0, dtype=torch.int32)
+
+            self.compress_k2_cache_varlen = torch.cat(self.compress_k2_cache, dim=0)
+            self.cached_compressed_cu_seqlens2 = torch.cat([torch.tensor([0], dtype=torch.int32), new_cumsum]).to(self.compress_k2_cache_varlen.device)
+        return self.compress_k2_cache_varlen, self.cached_compressed_cu_seqlens2
+
+    def update_no_compress_k2(self, key_states, kernel_size=128, kernel_stride=64):
+        k_chunk_list = []
+        for index, k in enumerate(key_states):
+            if len(self.no_compress_k2_cache) <= index:
+                self.no_compress_k2_cache.append(k)
+            else:
+                self.no_compress_k2_cache[index] = torch.cat([self.no_compress_k2_cache[index], k], dim=0)
+            current_len = self.no_compress_k2_cache[index].shape[0]
+            if current_len >= kernel_size:
+                k_chunk_list.append(self.no_compress_k2_cache[index][:kernel_size])
+                self.no_compress_k2_cache[index] = self.no_compress_k2_cache[index][kernel_stride:]
+            else:
+                k_chunk_list.append(None)
+        return k_chunk_list
+
 class InfLLMv2Cache(DynamicCache):
-    def __init__(self,
-                 config,num_hidden_layers: Optional[int] = None) -> None:
+    def __init__(self, config,num_hidden_layers: Optional[int] = None) -> None:
         super().__init__(config=config)
         self.layers = [InfLLMv2CacheLayer() for _ in range(num_hidden_layers)] if num_hidden_layers else []
         self._seen_tokens = 0
+

     def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
         if layer_idx == 0:
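The rolling buffer behind `update_no_compress_k2` above can be pictured with a standalone toy loop (one sequence, made-up head shapes, the k2 defaults from the diff): new keys accumulate until a full `kernel_size` window exists, the window is emitted and mean-pooled into one compressed position, and the buffer then slides forward by `kernel_stride`.

```python
import torch

kernel_size, kernel_stride = 128, 64          # defaults used for the k2 path
buffer = torch.empty(0, 2, 8)                 # pending keys: [len, num_kv_heads, head_dim]
emitted = []

for step in range(300):                       # decode loop, one new key per step
    new_key = torch.randn(1, 2, 8)
    buffer = torch.cat([buffer, new_key], dim=0)
    if buffer.shape[0] >= kernel_size:
        window = buffer[:kernel_size]                      # a full window is ready
        buffer = buffer[kernel_stride:]                    # slide forward by the stride
        emitted.append(window.mean(dim=0, keepdim=True))   # mean-pool it, as in get_compress_k

print(len(emitted), buffer.shape[0])          # 3 compressed positions emitted, 108 keys still pending
```
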
@@ -324,6 +343,12 @@ class InfLLMv2Cache(DynamicCache):
     def update_no_compress_k(self, key_states, layer_idx, kernel_size=32, kernel_stride=16, cache_kwargs=None):
         return self.layers[layer_idx].update_no_compress_k(key_states, kernel_size, kernel_stride)

+    def update_compress_k2(self, key_states, layer_idx, cu_seqlens=None, cache_kwargs=None):
+        return self.layers[layer_idx].update_compress_k2(key_states, cu_seqlens)
+
+    def update_no_compress_k2(self, key_states, layer_idx, kernel_size=128, kernel_stride=64, cache_kwargs=None):
+        return self.layers[layer_idx].update_no_compress_k2(key_states, kernel_size, kernel_stride)
+
     def crop(self, max_length):
         for layer in self.layers:
             layer.crop(max_length)
@@ -591,7 +616,6 @@ def _unpad_one_tensor(hidden_states, attention_mask):
     unpadded_states = index_first_axis(reshaped_states, indices)

     return unpadded_states, indices, cu_seqlens, max_seqlen_in_batch
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -998,7 +1022,9 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         self.local_blocks = self.window_size // self.block_size  # local_blocks
         self.topk = self.config.sparse_config.get('topk', 64) + (self.window_size//self.block_size)
         self.use_nope = self.config.sparse_config.get('use_nope', False)
+
         self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride)
+        self.compress_k2 = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size*4, kernel_stride=self.kernel_stride*4)

     def forward(
         self,
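`compress_k2` reuses `CompressK` with a 4x larger kernel and stride. Under the strided-window reading sketched earlier (an assumption about `CompressK`, not something this diff shows), that yields roughly a quarter as many compressed positions per sequence:

```python
def num_windows(seq_len: int, kernel_size: int, kernel_stride: int) -> int:
    # Number of complete strided windows that fit into a sequence of length seq_len.
    return 0 if seq_len < kernel_size else (seq_len - kernel_size) // kernel_stride + 1

seq_len = 8192
print(num_windows(seq_len, 32, 16))    # 511 positions for compress_k  (kernel 32, stride 16)
print(num_windows(seq_len, 128, 64))   # 127 positions for compress_k2 (kernel 128, stride 64)
```
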
@@ -1023,6 +1049,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):

         bsz, q_len, _ = hidden_states.size()

+
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
@@ -1053,11 +1080,12 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
         if self.use_nope:
-            key_states_no_rope =
+            key_states_no_rope = past_key_value.update_no_rope_key(key_states_no_rope, self.layer_idx)
             no_rope_param = {
                 'key_states_no_rope': key_states_no_rope,
                 'query_states_no_rope': query_states_no_rope,
             }
+
         else:
             no_rope_param = None

@@ -1103,16 +1131,8 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         return attn_output, attn_weights, past_key_value

     def _sparse_attention_forward(
-        self,
-
-        key_states,
-        value_states,
-        attention_mask,
-        query_length,
-        dropout=0.0,
-        softmax_scale=None,
-        no_rope_param=None,
-        past_key_value=None):
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, no_rope_param=None, past_key_value=None
+    ):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
         first unpad the input, then computes the attention scores and pad the final attention scores.
@@ -1142,15 +1162,17 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         batch_size = query_states.shape[0]
         # assert batch_size == 1, 'Only batch_size=1 is supported at the moment.'
         if past_key_value!=None:
-            compressed_k, compressed_cu_seqlens = self.get_compress_k(
+            compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2 = self.get_compress_k(
                 key_states=key_states if self.use_nope ==False else no_rope_param['key_states_no_rope'],  # This can be optimized a bit;
                 attention_mask=attention_mask,
-                past_key_value=past_key_value
+                past_key_value=past_key_value,
+
+            )

         query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
             query_states, key_states, value_states, attention_mask, query_length
         )
-
+
         cu_seqlens_q, cu_seqlens_k = cu_seq_lens
         max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
         if no_rope_param != None:
@@ -1161,7 +1183,12 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         if past_key_value==None:
             # compress_k use varlen form
             compressed_k, compressed_cu_seqlens = self.compress_k(key_states,cu_seqlens_k)
+            compressed_k2, compressed_cu_seqlens2 = self.compress_k2(key_states,cu_seqlens_k)
+        else:
+            # compressed_k and compressed_k2 already retrieved from get_compress_k above
+            pass

+
         attn_output_unpad = self.sparse_forward(
             query_states,
             key_states,
@@ -1171,15 +1198,16 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
             max_seqlen_in_batch_q,
             max_seqlen_in_batch_k,
             no_rope_param=no_rope_param,
-            compressed_k=compressed_k,
-
+            compressed_k=compressed_k, compressed_cu_seqlens=compressed_cu_seqlens,
+            compressed_k2=compressed_k2, compressed_cu_seqlens2=compressed_cu_seqlens2
+        )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+
         else:
             raise ValueError('Need attention mask')

         return attn_output
-
     def get_compress_k(self, key_states, attention_mask, past_key_value):
         """
         Get compressed key states and corresponding cumulative sequence lengths.
@@ -1191,34 +1219,51 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         no_rope_param: Optional parameter containing key states without rope

         Returns:
-            Tuple of (compressed_k, compressed_cu_seqlens)
+            Tuple of (compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2)
         """
+
         # Check if this is prefilling or initial compression condition
+
         is_prefilling = (
             key_states.shape[1] >= self.dense_len and
             (
                 not past_key_value.layers[self.layer_idx].compress_k_cache
             )
         )
-
+
         if is_prefilling:
             unpadded_key_states, indices, cu_seqlens, max_seqlen_in_batch = _unpad_one_tensor(key_states,attention_mask=attention_mask)
             # Compress the keys
             compressed_k, compressed_cu_seqlens = self.compress_k(unpadded_key_states, cu_seqlens)
-
+            compressed_k2, compressed_cu_seqlens2 = self.compress_k2(unpadded_key_states, cu_seqlens)
+
             past_key_value.update_compress_k(
                 compressed_k, self.layer_idx, compressed_cu_seqlens)
-
+            past_key_value.update_compress_k2(
+                compressed_k2, self.layer_idx, compressed_cu_seqlens2)
+
             no_compress_k_list = []
             # Compute and update no_compress_k
             for i in range(len(compressed_cu_seqlens)-1):
                 no_compress_k_start = (compressed_cu_seqlens[i+1]- compressed_cu_seqlens[i]) * self.kernel_stride
-
+
                 no_compress_k_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k_start:cu_seqlens[i+1]].clone())

             past_key_value.update_no_compress_k(
                 no_compress_k_list, self.layer_idx,kernel_stride=self.kernel_stride,
                 kernel_size=self.kernel_size)
+
+            # Also update no_compress_k2
+            no_compress_k2_list = []
+            for i in range(len(compressed_cu_seqlens2)-1):
+                no_compress_k2_start = (compressed_cu_seqlens2[i+1]- compressed_cu_seqlens2[i]) * self.kernel_stride * 4
+
+                no_compress_k2_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k2_start:cu_seqlens[i+1]].clone())
+
+            past_key_value.update_no_compress_k2(
+                no_compress_k2_list, self.layer_idx,kernel_stride=self.kernel_stride*4,
+                kernel_size=self.kernel_size*4)
+
         else:
             # Decode case: incremental update
             batch_size = key_states.shape[0]  # key_states.shape = [batch_size, seq, k_head_num, head_dim]
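The `no_compress_k_start` arithmetic in the prefill branch above keeps the tail of tokens that have not yet completed a full compression window, so they can be compressed incrementally later. A toy check of that bookkeeping with made-up lengths and the 32/16 defaults:

```python
import torch

seq_len, kernel_size, kernel_stride = 100, 32, 16
keys = torch.randn(seq_len, 2, 8)                            # [len, num_kv_heads, head_dim]

num_windows = (seq_len - kernel_size) // kernel_stride + 1   # 5 complete windows were compressed
no_compress_start = num_windows * kernel_stride              # 80: first index not fully covered
pending_tail = keys[no_compress_start:]                      # 20 tokens kept for incremental updates
print(num_windows, no_compress_start, pending_tail.shape[0])  # 5 80 20
```
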
@@ -1233,16 +1278,32 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
                 kernel_size=self.kernel_size)
             new_compressed_k_list = []
             for no_compress_k in no_compress_k_list:
+
                 if no_compress_k is not None:
                     # We have enough tokens to compress
                     new_compressed_k = no_compress_k.mean(dim=0, keepdim=True)  # [1, n_heads_k, head_dim]
+
                     new_compressed_k_list.append(new_compressed_k)
                 else:
                     new_compressed_k_list.append(None)
             compressed_k, compressed_cu_seqlens = past_key_value.update_compress_k(new_compressed_k_list, self.layer_idx,)
-
-
-
+
+            # For compress_k2, update no_compress_k2 buffer and compress when ready
+            no_compress_k2_list = past_key_value.update_no_compress_k2(
+                key_states_split, self.layer_idx,
+                kernel_stride=self.kernel_stride*4,
+                kernel_size=self.kernel_size*4)
+            new_compressed_k2_list = []
+            for no_compress_k2 in no_compress_k2_list:
+                if no_compress_k2 is not None:
+                    # We have enough tokens to compress for k2
+                    new_compressed_k2 = no_compress_k2.mean(dim=0, keepdim=True)  # [1, n_heads_k, head_dim]
+                    new_compressed_k2_list.append(new_compressed_k2)
+                else:
+                    new_compressed_k2_list.append(None)
+            compressed_k2, compressed_cu_seqlens2 = past_key_value.update_compress_k2(new_compressed_k2_list, self.layer_idx,)
+
+        return compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2
     def sparse_forward(self,
                        query_layer,
                        key_layer,
@@ -1252,8 +1313,8 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
                        max_seqlen_in_batch_q,
                        max_seqlen_in_batch_k,
                        no_rope_param=None,
-                       compressed_k=None,
-
+                       compressed_k=None, compressed_cu_seqlens=None,
+                       compressed_k2=None, compressed_cu_seqlens2=None):
         compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
         cache_lens = None
         if max_seqlen_in_batch_q==1 and max_seqlen_in_batch_k>1: #decoding
@@ -1263,13 +1324,14 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         topk_idx = compressed_attention(
             query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'],
             compressed_k,
-
+            compressed_k2,
             self.kernel_size,
             self.kernel_stride,
             self.block_size,
             self.topk,
             cu_seqlens_q,
             compressed_cu_seqlens,
+            compressed_cu_seqlens2,
             max_seqlen_in_batch_q,
             compressed_seqlens.max().item(),
             None,
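Finally, `sparse_forward` derives per-sequence compressed lengths from the cumulative offsets and recognises decoding when each sequence contributes a single query token while the key side still spans the cached context. A minimal illustration with made-up numbers:

```python
import torch

# Cumulative offsets of the compressed keys for a batch of three sequences.
compressed_cu_seqlens = torch.tensor([0, 63, 150, 182], dtype=torch.int32)
compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
print(compressed_seqlens.tolist(), compressed_seqlens.max().item())   # [63, 87, 32] 87

# Decode is recognised the same way as in sparse_forward: one query token per
# sequence while the key side still spans the cached context.
max_seqlen_in_batch_q, max_seqlen_in_batch_k = 1, 4096
print(max_seqlen_in_batch_q == 1 and max_seqlen_in_batch_k > 1)       # True -> decoding path
```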