funmaker committed on
Commit 3a8dfed · verified · 1 Parent(s): 29e0d7a

Update modeling_minicpm.py

Modify modeling_minicpm.py to use LSE compression.
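
The diff below adds a second, coarser compressed-key stream (compressed_k2, built with kernel_size*4 and kernel_stride*4) and threads it through compressed_attention as k2 / cu_seqlens_k2. For orientation only, here is a minimal stand-alone PyTorch sketch of the windowed-mean reduction the decode path applies (mirroring no_compress_k.mean(dim=0, keepdim=True)); mean_compress and the example shapes are illustrative assumptions, not the repo's CompressK module.

import torch

def mean_compress(k: torch.Tensor, kernel_size: int, kernel_stride: int) -> torch.Tensor:
    # k: [seq_len, num_kv_heads, head_dim] -> [num_windows, num_kv_heads, head_dim]
    windows = []
    for start in range(0, k.shape[0] - kernel_size + 1, kernel_stride):
        # Same reduction as the decode path: no_compress_k.mean(dim=0, keepdim=True)
        windows.append(k[start:start + kernel_size].mean(dim=0, keepdim=True))
    return torch.cat(windows, dim=0) if windows else k.new_zeros((0,) + tuple(k.shape[1:]))

k = torch.randn(1024, 2, 128)                     # [seq, kv_heads, head_dim], illustrative sizes
compressed_k = mean_compress(k, 32, 16)           # fine stream (kernel_size=32, kernel_stride=16)
compressed_k2 = mean_compress(k, 32 * 4, 16 * 4)  # coarse stream passed to compressed_attention as k2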

Files changed (1)
  1. modeling_minicpm.py +134 -72
modeling_minicpm.py CHANGED
@@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
21
  import torch
22
  import torch.nn.functional as F
23
  import torch.utils.checkpoint
24
- from torch import nn
25
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
  from transformers.activations import ACT2FN
27
  from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin, DynamicLayer
@@ -47,7 +47,9 @@ from transformers.utils import (
47
  )
48
  from transformers.utils.import_utils import is_torch_fx_available
49
 
50
- from .configuration_minicpm import MiniCPMConfig
 
 
51
 
52
  try:
53
  from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -68,50 +70,28 @@ from functools import lru_cache
68
  def compressed_attention(
69
  q: torch.Tensor,
70
  k: torch.Tensor,
71
- v: torch.Tensor,
72
  kernel_size: int,
73
  kernel_stride: int,
74
  block_size: int,
75
  topk: int,
76
  cu_seqlens_q: torch.Tensor,
77
  cu_seqlens_k: torch.Tensor,
 
78
  max_seqlen_q: int,
79
  max_seqlen_k: int,
80
  sm_scale: float = None,
81
  init_blocks: int = 1,
82
  local_blocks: int = 2,
83
- cache_lens: torch.Tensor = None,
84
  ) -> Tuple[torch.Tensor, torch.Tensor]:
85
- """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.
86
-
87
- Args:
88
- q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
89
- k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
90
- v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
91
- kernel_size (int): kernel size in compress_key_value
92
- kernel_stride (int): stride of compress_key_value
93
- block_size (int): key value block size for topk sparse attention.
94
- topk (int): number of blocks for each query.
95
- cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
96
- cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
97
- max_seqlen_q (int): max q len of the batch.
98
- max_seqlen_k (int): max k len of the batch.
99
- sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
100
- init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
101
- local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
102
- cache_lens (torch.Tensor, optional): shape [batch_size], used to record the cache length of each query. Defaults to None.
103
-
104
- Returns:
105
- Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
106
- """
107
  with torch.no_grad():
108
  batch_size = cu_seqlens_q.shape[0] - 1
109
 
110
  # Check if it's prefilling stage
111
  is_prefilling = cache_lens is None or (cache_lens == 0).all().item()
112
-
113
- # prefilling stage
114
- if is_prefilling:
115
  # Calculate q_idx for each query position in each batch
116
  cache_lens = torch.zeros(batch_size, dtype=torch.int32, device=q.device)
117
  q_idx = torch.cat([
@@ -119,25 +99,24 @@ def compressed_attention(
119
  max_seqlen_q - (cu_seqlens_q[i + 1] - cu_seqlens_q[i])) // block_size
120
  for i in range(batch_size)
121
  ], dim=0) # shape: [total_q_len]
122
- # decoding stage
123
- else:
124
- # Each batch has only one query (last position). Shape: [batch_size] = [total_q_len] in decoding
125
- q_idx = cache_lens // block_size
126
 
127
- # compute attention score
128
  score = infllmv2_attn_stage1(
129
  q.contiguous(),
130
  k.contiguous(),
131
- v.contiguous(),
132
  cu_seqlens_q=cu_seqlens_q,
133
  cu_seqlens_k=cu_seqlens_k,
 
134
  max_seqlen_q=max_seqlen_q,
135
  max_seqlen_k=max_seqlen_k,
136
- causal=is_prefilling)
137
- # Shape: [num_heads, total_q_len, num_blocks]
138
- score = score[:, :q_idx.shape[0], :]
139
-
140
- # Shape: [num_heads, total_q_len, num_blocks]
141
  block_score = max_pooling_1d_varlen(
142
  score.contiguous(),
143
  cu_seqlens_q,
@@ -148,7 +127,9 @@ def compressed_attention(
148
  local_blocks=local_blocks,
149
  init_blocks=init_blocks,
150
  block_size=block_size,
151
- stride=kernel_stride)
 
 
152
 
153
  # get topk
154
  topk = min(topk, block_score.shape[-1])
@@ -262,6 +243,11 @@ class InfLLMv2CacheLayer(DynamicLayer):
262
  self.no_compress_k_cache = []
263
  self.cached_compressed_cu_seqlens = torch.tensor([], dtype=torch.int32)
264
  self.compress_k_cache_varlen = torch.tensor([], dtype=torch.float32)
265
 
266
  def update_no_rope_key(self, key_states):
267
  if self.no_rope_keys.numel() == 0:
@@ -303,12 +289,45 @@ class InfLLMv2CacheLayer(DynamicLayer):
303
  k_chunk_list.append(None)
304
  return k_chunk_list
305

306
  class InfLLMv2Cache(DynamicCache):
307
- def __init__(self,
308
- config,num_hidden_layers: Optional[int] = None) -> None:
309
  super().__init__(config=config)
310
  self.layers = [InfLLMv2CacheLayer() for _ in range(num_hidden_layers)] if num_hidden_layers else []
311
  self._seen_tokens = 0
 
312
 
313
  def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
314
  if layer_idx == 0:
@@ -324,6 +343,12 @@ class InfLLMv2Cache(DynamicCache):
324
  def update_no_compress_k(self, key_states, layer_idx, kernel_size=32, kernel_stride=16, cache_kwargs=None):
325
  return self.layers[layer_idx].update_no_compress_k(key_states, kernel_size, kernel_stride)
326

327
  def crop(self, max_length):
328
  for layer in self.layers:
329
  layer.crop(max_length)
@@ -591,7 +616,6 @@ def _unpad_one_tensor(hidden_states, attention_mask):
591
  unpadded_states = index_first_axis(reshaped_states, indices)
592
 
593
  return unpadded_states, indices, cu_seqlens, max_seqlen_in_batch
594
-
595
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
596
  """
597
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -998,7 +1022,9 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
998
  self.local_blocks = self.window_size // self.block_size # local_blocks
999
  self.topk = self.config.sparse_config.get('topk', 64) + (self.window_size//self.block_size)
1000
  self.use_nope = self.config.sparse_config.get('use_nope', False)
 
1001
  self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride)
 
1002
 
1003
  def forward(
1004
  self,
@@ -1023,6 +1049,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1023
 
1024
  bsz, q_len, _ = hidden_states.size()
1025
 
 
1026
  query_states = self.q_proj(hidden_states)
1027
  key_states = self.k_proj(hidden_states)
1028
  value_states = self.v_proj(hidden_states)
@@ -1053,11 +1080,12 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1053
  key_states = key_states.transpose(1, 2)
1054
  value_states = value_states.transpose(1, 2)
1055
  if self.use_nope:
1056
- key_states_no_rope = past_key_value.update_no_rope_key(key_states_no_rope, self.layer_idx)
1057
  no_rope_param = {
1058
  'key_states_no_rope': key_states_no_rope,
1059
  'query_states_no_rope': query_states_no_rope,
1060
  }
 
1061
  else:
1062
  no_rope_param = None
1063
 
@@ -1103,16 +1131,8 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1103
  return attn_output, attn_weights, past_key_value
1104
 
1105
  def _sparse_attention_forward(
1106
- self,
1107
- query_states,
1108
- key_states,
1109
- value_states,
1110
- attention_mask,
1111
- query_length,
1112
- dropout=0.0,
1113
- softmax_scale=None,
1114
- no_rope_param=None,
1115
- past_key_value=None):
1116
  """
1117
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
1118
  first unpad the input, then computes the attention scores and pad the final attention scores.
@@ -1142,15 +1162,17 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1142
  batch_size = query_states.shape[0]
1143
  # assert batch_size == 1, 'Only batch_size=1 is supported at the moment.'
1144
  if past_key_value!=None:
1145
- compressed_k, compressed_cu_seqlens = self.get_compress_k(
1146
  key_states=key_states if self.use_nope ==False else no_rope_param['key_states_no_rope'], # This can be optimized a bit;
1147
  attention_mask=attention_mask,
1148
- past_key_value=past_key_value)
 
 
1149
 
1150
  query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
1151
  query_states, key_states, value_states, attention_mask, query_length
1152
  )
1153
-
1154
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
1155
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
1156
  if no_rope_param != None:
@@ -1161,7 +1183,12 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1161
  if past_key_value==None:
1162
  # compress_k use varlen form
1163
  compressed_k, compressed_cu_seqlens = self.compress_k(key_states,cu_seqlens_k)
 
 
 
 
1164
 
 
1165
  attn_output_unpad = self.sparse_forward(
1166
  query_states,
1167
  key_states,
@@ -1171,15 +1198,16 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1171
  max_seqlen_in_batch_q,
1172
  max_seqlen_in_batch_k,
1173
  no_rope_param=no_rope_param,
1174
- compressed_k=compressed_k,
1175
- compressed_cu_seqlens=compressed_cu_seqlens)
 
1176
 
1177
  attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
1178
  else:
1179
  raise ValueError('Need attention mask')
1180
 
1181
  return attn_output
1182
-
1183
  def get_compress_k(self, key_states, attention_mask, past_key_value):
1184
  """
1185
  Get compressed key states and corresponding cumulative sequence lengths.
@@ -1191,34 +1219,51 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1191
  no_rope_param: Optional parameter containing key states without rope
1192
 
1193
  Returns:
1194
- Tuple of (compressed_k, compressed_cu_seqlens)
1195
  """
 
1196
  # Check if this is prefilling or initial compression condition
 
1197
  is_prefilling = (
1198
  key_states.shape[1] >= self.dense_len and
1199
  (
1200
  not past_key_value.layers[self.layer_idx].compress_k_cache
1201
  )
1202
  )
1203
-
1204
  if is_prefilling:
1205
  unpadded_key_states, indices, cu_seqlens, max_seqlen_in_batch = _unpad_one_tensor(key_states,attention_mask=attention_mask)
1206
  # Compress the keys
1207
  compressed_k, compressed_cu_seqlens = self.compress_k(unpadded_key_states, cu_seqlens)
1208
-
 
1209
  past_key_value.update_compress_k(
1210
  compressed_k, self.layer_idx, compressed_cu_seqlens)
1211
-
 
 
1212
  no_compress_k_list = []
1213
  # Compute and update no_compress_k
1214
  for i in range(len(compressed_cu_seqlens)-1):
1215
  no_compress_k_start = (compressed_cu_seqlens[i+1]- compressed_cu_seqlens[i]) * self.kernel_stride
1216
-
1217
  no_compress_k_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k_start:cu_seqlens[i+1]].clone())
1218
 
1219
  past_key_value.update_no_compress_k(
1220
  no_compress_k_list, self.layer_idx,kernel_stride=self.kernel_stride,
1221
  kernel_size=self.kernel_size)
1222
  else:
1223
  # Decode case: incremental update
1224
  batch_size = key_states.shape[0] # key_states.shape = [batch_size, seq, k_head_num, head_dim]
@@ -1233,16 +1278,32 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1233
  kernel_size=self.kernel_size)
1234
  new_compressed_k_list = []
1235
  for no_compress_k in no_compress_k_list:
 
1236
  if no_compress_k is not None:
1237
  # We have enough tokens to compress
1238
  new_compressed_k = no_compress_k.mean(dim=0, keepdim=True) # [1, n_heads_k, head_dim]
 
1239
  new_compressed_k_list.append(new_compressed_k)
1240
  else:
1241
  new_compressed_k_list.append(None)
1242
  compressed_k, compressed_cu_seqlens = past_key_value.update_compress_k(new_compressed_k_list, self.layer_idx,)
1243
-
1244
- return compressed_k, compressed_cu_seqlens
1245
-
1246
  def sparse_forward(self,
1247
  query_layer,
1248
  key_layer,
@@ -1252,8 +1313,8 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1252
  max_seqlen_in_batch_q,
1253
  max_seqlen_in_batch_k,
1254
  no_rope_param=None,
1255
- compressed_k=None,
1256
- compressed_cu_seqlens=None):
1257
  compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
1258
  cache_lens = None
1259
  if max_seqlen_in_batch_q==1 and max_seqlen_in_batch_k>1: #decoding
@@ -1263,13 +1324,14 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1263
  topk_idx = compressed_attention(
1264
  query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'],
1265
  compressed_k,
1266
- compressed_k.clone(),
1267
  self.kernel_size,
1268
  self.kernel_stride,
1269
  self.block_size,
1270
  self.topk,
1271
  cu_seqlens_q,
1272
  compressed_cu_seqlens,
 
1273
  max_seqlen_in_batch_q,
1274
  compressed_seqlens.max().item(),
1275
  None,
 
21
  import torch
22
  import torch.nn.functional as F
23
  import torch.utils.checkpoint
24
+ from torch import nn
25
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
  from transformers.activations import ACT2FN
27
  from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin, DynamicLayer
 
47
  )
48
  from transformers.utils.import_utils import is_torch_fx_available
49
 
50
+
51
+
52
+ from .configuration_minicpm import MiniCPMConfig  #! must change this
53
 
54
  try:
55
  from flash_attn import flash_attn_func, flash_attn_varlen_func
 
70
  def compressed_attention(
71
  q: torch.Tensor,
72
  k: torch.Tensor,
73
+ k2: torch.Tensor,
74
  kernel_size: int,
75
  kernel_stride: int,
76
  block_size: int,
77
  topk: int,
78
  cu_seqlens_q: torch.Tensor,
79
  cu_seqlens_k: torch.Tensor,
80
+ cu_seqlens_k2: torch.Tensor,
81
  max_seqlen_q: int,
82
  max_seqlen_k: int,
83
  sm_scale: float = None,
84
  init_blocks: int = 1,
85
  local_blocks: int = 2,
86
+ cache_lens=None,
87
  ) -> Tuple[torch.Tensor, torch.Tensor]:
88
  with torch.no_grad():
89
  batch_size = cu_seqlens_q.shape[0] - 1
90
 
91
  # Check if it's prefilling stage
92
  is_prefilling = cache_lens is None or (cache_lens == 0).all().item()
93
+
94
+ if is_prefilling: # prefilling stage
 
95
  # Calculate q_idx for each query position in each batch
96
  cache_lens = torch.zeros(batch_size, dtype=torch.int32, device=q.device)
97
  q_idx = torch.cat([
 
99
  max_seqlen_q - (cu_seqlens_q[i + 1] - cu_seqlens_q[i])) // block_size
100
  for i in range(batch_size)
101
  ], dim=0) # shape: [total_q_len]
102
+ else: # decoding stage
103
+ # Each batch has only one query (last position)
104
+ q_idx = cache_lens // block_size # shape: [batch_size] = [total_q_len] in decoding
 
105
 
106
+ # compute attention score
107
  score = infllmv2_attn_stage1(
108
  q.contiguous(),
109
  k.contiguous(),
110
+ k2.contiguous(),
111
  cu_seqlens_q=cu_seqlens_q,
112
  cu_seqlens_k=cu_seqlens_k,
113
+ cu_seqlens_v=cu_seqlens_k2,
114
  max_seqlen_q=max_seqlen_q,
115
  max_seqlen_k=max_seqlen_k,
116
+ causal=is_prefilling
117
+ )
118
+ score = score[:, :q_idx.shape[0], :] # [num_heads, total_q_len, num_blocks]
119
+
 
120
  block_score = max_pooling_1d_varlen(
121
  score.contiguous(),
122
  cu_seqlens_q,
 
127
  local_blocks=local_blocks,
128
  init_blocks=init_blocks,
129
  block_size=block_size,
130
+ stride=kernel_stride
131
+ ) # shape: [num_heads, total_q_len, num_blocks]
132
+
133
 
134
  # get topk
135
  topk = min(topk, block_score.shape[-1])
 
243
  self.no_compress_k_cache = []
244
  self.cached_compressed_cu_seqlens = torch.tensor([], dtype=torch.int32)
245
  self.compress_k_cache_varlen = torch.tensor([], dtype=torch.float32)
246
+ # Add support for compress_k2
247
+ self.compress_k2_cache = []
248
+ self.cached_compressed_cu_seqlens2 = torch.tensor([], dtype=torch.int32)
249
+ self.compress_k2_cache_varlen = torch.tensor([], dtype=torch.float32)
250
+ self.no_compress_k2_cache = []
251
 
252
  def update_no_rope_key(self, key_states):
253
  if self.no_rope_keys.numel() == 0:
 
289
  k_chunk_list.append(None)
290
  return k_chunk_list
291
 
292
+ def update_compress_k2(self, key_states, cu_seqlens=None):
293
+ if len(self.compress_k2_cache) == 0:
294
+ if cu_seqlens is not None:
295
+ self.cached_compressed_cu_seqlens2 = cu_seqlens.clone()
296
+ self.compress_k2_cache_varlen = key_states
297
+ split_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
298
+ self.compress_k2_cache = list(torch.split(key_states, split_sizes))
299
+ else:
300
+ for index, k in enumerate(key_states):
301
+ if k is not None:
302
+ self.compress_k2_cache[index] = torch.cat([self.compress_k2_cache[index], k], dim=0)
303
+ new_seq_lens = torch.tensor([tensor.shape[0] for tensor in self.compress_k2_cache], dtype=torch.int32)
304
+ new_cumsum = torch.cumsum(new_seq_lens, dim=0, dtype=torch.int32)
305
+
306
+ self.compress_k2_cache_varlen = torch.cat(self.compress_k2_cache, dim=0)
307
+ self.cached_compressed_cu_seqlens2 = torch.cat([torch.tensor([0], dtype=torch.int32), new_cumsum]).to(self.compress_k2_cache_varlen.device)
308
+ return self.compress_k2_cache_varlen, self.cached_compressed_cu_seqlens2
309
+
310
+ def update_no_compress_k2(self, key_states, kernel_size=128, kernel_stride=64):
311
+ k_chunk_list = []
312
+ for index, k in enumerate(key_states):
313
+ if len(self.no_compress_k2_cache) <= index:
314
+ self.no_compress_k2_cache.append(k)
315
+ else:
316
+ self.no_compress_k2_cache[index] = torch.cat([self.no_compress_k2_cache[index], k], dim=0)
317
+ current_len = self.no_compress_k2_cache[index].shape[0]
318
+ if current_len >= kernel_size:
319
+ k_chunk_list.append(self.no_compress_k2_cache[index][:kernel_size])
320
+ self.no_compress_k2_cache[index] = self.no_compress_k2_cache[index][kernel_stride:]
321
+ else:
322
+ k_chunk_list.append(None)
323
+ return k_chunk_list
324
+
325
  class InfLLMv2Cache(DynamicCache):
326
+ def __init__(self, config,num_hidden_layers: Optional[int] = None) -> None:
 
327
  super().__init__(config=config)
328
  self.layers = [InfLLMv2CacheLayer() for _ in range(num_hidden_layers)] if num_hidden_layers else []
329
  self._seen_tokens = 0
330
+
331
 
332
  def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
333
  if layer_idx == 0:
 
343
  def update_no_compress_k(self, key_states, layer_idx, kernel_size=32, kernel_stride=16, cache_kwargs=None):
344
  return self.layers[layer_idx].update_no_compress_k(key_states, kernel_size, kernel_stride)
345
 
346
+ def update_compress_k2(self, key_states, layer_idx, cu_seqlens=None, cache_kwargs=None):
347
+ return self.layers[layer_idx].update_compress_k2(key_states, cu_seqlens)
348
+
349
+ def update_no_compress_k2(self, key_states, layer_idx, kernel_size=128, kernel_stride=64, cache_kwargs=None):
350
+ return self.layers[layer_idx].update_no_compress_k2(key_states, kernel_size, kernel_stride)
351
+
352
  def crop(self, max_length):
353
  for layer in self.layers:
354
  layer.crop(max_length)
 
616
  unpadded_states = index_first_axis(reshaped_states, indices)
617
 
618
  return unpadded_states, indices, cu_seqlens, max_seqlen_in_batch
 
619
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
620
  """
621
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
 
1022
  self.local_blocks = self.window_size // self.block_size # local_blocks
1023
  self.topk = self.config.sparse_config.get('topk', 64) + (self.window_size//self.block_size)
1024
  self.use_nope = self.config.sparse_config.get('use_nope', False)
1025
+
1026
  self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride)
1027
+ self.compress_k2 = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size*4, kernel_stride=self.kernel_stride*4)
1028
 
1029
  def forward(
1030
  self,
 
1049
 
1050
  bsz, q_len, _ = hidden_states.size()
1051
 
1052
+
1053
  query_states = self.q_proj(hidden_states)
1054
  key_states = self.k_proj(hidden_states)
1055
  value_states = self.v_proj(hidden_states)
 
1080
  key_states = key_states.transpose(1, 2)
1081
  value_states = value_states.transpose(1, 2)
1082
  if self.use_nope:
1083
+ key_states_no_rope =past_key_value.update_no_rope_key(key_states_no_rope, self.layer_idx)
1084
  no_rope_param = {
1085
  'key_states_no_rope': key_states_no_rope,
1086
  'query_states_no_rope': query_states_no_rope,
1087
  }
1088
+
1089
  else:
1090
  no_rope_param = None
1091
 
 
1131
  return attn_output, attn_weights, past_key_value
1132
 
1133
  def _sparse_attention_forward(
1134
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, no_rope_param=None, past_key_value=None
1135
+ ):
 
 
 
 
 
 
 
 
1136
  """
1137
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
1138
  first unpad the input, then computes the attention scores and pad the final attention scores.
 
1162
  batch_size = query_states.shape[0]
1163
  # assert batch_size == 1, 'Only batch_size=1 is supported at the moment.'
1164
  if past_key_value!=None:
1165
+ compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2 = self.get_compress_k(
1166
  key_states=key_states if self.use_nope ==False else no_rope_param['key_states_no_rope'], # This can be optimized a bit;
1167
  attention_mask=attention_mask,
1168
+ past_key_value=past_key_value,
1169
+
1170
+ )
1171
 
1172
  query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
1173
  query_states, key_states, value_states, attention_mask, query_length
1174
  )
1175
+
1176
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
1177
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
1178
  if no_rope_param != None:
 
1183
  if past_key_value==None:
1184
  # compress_k use varlen form
1185
  compressed_k, compressed_cu_seqlens = self.compress_k(key_states,cu_seqlens_k)
1186
+ compressed_k2, compressed_cu_seqlens2 = self.compress_k2(key_states,cu_seqlens_k)
1187
+ else:
1188
+ # compressed_k and compressed_k2 already retrieved from get_compress_k above
1189
+ pass
1190
 
1191
+
1192
  attn_output_unpad = self.sparse_forward(
1193
  query_states,
1194
  key_states,
 
1198
  max_seqlen_in_batch_q,
1199
  max_seqlen_in_batch_k,
1200
  no_rope_param=no_rope_param,
1201
+ compressed_k=compressed_k, compressed_cu_seqlens=compressed_cu_seqlens,
1202
+ compressed_k2=compressed_k2, compressed_cu_seqlens2=compressed_cu_seqlens2
1203
+ )
1204
 
1205
  attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
1206
+
1207
  else:
1208
  raise ValueError('Need attention mask')
1209
 
1210
  return attn_output
 
1211
  def get_compress_k(self, key_states, attention_mask, past_key_value):
1212
  """
1213
  Get compressed key states and corresponding cumulative sequence lengths.
 
1219
  no_rope_param: Optional parameter containing key states without rope
1220
 
1221
  Returns:
1222
+ Tuple of (compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2)
1223
  """
1224
+
1225
  # Check if this is prefilling or initial compression condition
1226
+
1227
  is_prefilling = (
1228
  key_states.shape[1] >= self.dense_len and
1229
  (
1230
  not past_key_value.layers[self.layer_idx].compress_k_cache
1231
  )
1232
  )
1233
+
1234
  if is_prefilling:
1235
  unpadded_key_states, indices, cu_seqlens, max_seqlen_in_batch = _unpad_one_tensor(key_states,attention_mask=attention_mask)
1236
  # Compress the keys
1237
  compressed_k, compressed_cu_seqlens = self.compress_k(unpadded_key_states, cu_seqlens)
1238
+ compressed_k2, compressed_cu_seqlens2 = self.compress_k2(unpadded_key_states, cu_seqlens)
1239
+
1240
  past_key_value.update_compress_k(
1241
  compressed_k, self.layer_idx, compressed_cu_seqlens)
1242
+ past_key_value.update_compress_k2(
1243
+ compressed_k2, self.layer_idx, compressed_cu_seqlens2)
1244
+
1245
  no_compress_k_list = []
1246
  # Compute and update no_compress_k
1247
  for i in range(len(compressed_cu_seqlens)-1):
1248
  no_compress_k_start = (compressed_cu_seqlens[i+1]- compressed_cu_seqlens[i]) * self.kernel_stride
1249
+
1250
  no_compress_k_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k_start:cu_seqlens[i+1]].clone())
1251
 
1252
  past_key_value.update_no_compress_k(
1253
  no_compress_k_list, self.layer_idx,kernel_stride=self.kernel_stride,
1254
  kernel_size=self.kernel_size)
1255
+
1256
+ # Also update no_compress_k2
1257
+ no_compress_k2_list = []
1258
+ for i in range(len(compressed_cu_seqlens2)-1):
1259
+ no_compress_k2_start = (compressed_cu_seqlens2[i+1]- compressed_cu_seqlens2[i]) * self.kernel_stride * 4
1260
+
1261
+ no_compress_k2_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k2_start:cu_seqlens[i+1]].clone())
1262
+
1263
+ past_key_value.update_no_compress_k2(
1264
+ no_compress_k2_list, self.layer_idx,kernel_stride=self.kernel_stride*4,
1265
+ kernel_size=self.kernel_size*4)
1266
+
1267
  else:
1268
  # Decode case: incremental update
1269
  batch_size = key_states.shape[0] # key_states.shape = [batch_size, seq, k_head_num, head_dim]
 
1278
  kernel_size=self.kernel_size)
1279
  new_compressed_k_list = []
1280
  for no_compress_k in no_compress_k_list:
1281
+
1282
  if no_compress_k is not None:
1283
  # We have enough tokens to compress
1284
  new_compressed_k = no_compress_k.mean(dim=0, keepdim=True) # [1, n_heads_k, head_dim]
1285
+
1286
  new_compressed_k_list.append(new_compressed_k)
1287
  else:
1288
  new_compressed_k_list.append(None)
1289
  compressed_k, compressed_cu_seqlens = past_key_value.update_compress_k(new_compressed_k_list, self.layer_idx,)
1290
+
1291
+ # For compress_k2, update no_compress_k2 buffer and compress when ready
1292
+ no_compress_k2_list = past_key_value.update_no_compress_k2(
1293
+ key_states_split, self.layer_idx,
1294
+ kernel_stride=self.kernel_stride*4,
1295
+ kernel_size=self.kernel_size*4)
1296
+ new_compressed_k2_list = []
1297
+ for no_compress_k2 in no_compress_k2_list:
1298
+ if no_compress_k2 is not None:
1299
+ # We have enough tokens to compress for k2
1300
+ new_compressed_k2 = no_compress_k2.mean(dim=0, keepdim=True) # [1, n_heads_k, head_dim]
1301
+ new_compressed_k2_list.append(new_compressed_k2)
1302
+ else:
1303
+ new_compressed_k2_list.append(None)
1304
+ compressed_k2, compressed_cu_seqlens2 = past_key_value.update_compress_k2(new_compressed_k2_list, self.layer_idx,)
1305
+
1306
+ return compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2
1307
  def sparse_forward(self,
1308
  query_layer,
1309
  key_layer,
 
1313
  max_seqlen_in_batch_q,
1314
  max_seqlen_in_batch_k,
1315
  no_rope_param=None,
1316
+ compressed_k=None, compressed_cu_seqlens=None,
1317
+ compressed_k2=None, compressed_cu_seqlens2=None):
1318
  compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
1319
  cache_lens = None
1320
  if max_seqlen_in_batch_q==1 and max_seqlen_in_batch_k>1: #decoding
 
1324
  topk_idx = compressed_attention(
1325
  query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'],
1326
  compressed_k,
1327
+ compressed_k2,
1328
  self.kernel_size,
1329
  self.kernel_stride,
1330
  self.block_size,
1331
  self.topk,
1332
  cu_seqlens_q,
1333
  compressed_cu_seqlens,
1334
+ compressed_cu_seqlens2,
1335
  max_seqlen_in_batch_q,
1336
  compressed_seqlens.max().item(),
1337
  None,