import torch
import flashinfer  # required by the FlashInfer backend paths below
import matplotlib.pyplot as plt
from einops import rearrange, repeat
from sageattention import sageattn
from spas_sage_attn import block_sparse_sage2_attn_cuda

def get_cuda_arch_versions():
    cuda_archs = []
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        cuda_archs.append(f"sm{major}{minor}")
    return cuda_archs

def sparge_mask_convert(mask: torch.Tensor, block_size: int = 128, arch="sm") -> torch.Tensor:
    """
    Convert a square block mask (block_size x block_size granularity) to the
    tile granularity expected by the SpargeAttn kernel, which differs by
    architecture: rows are split on sm90, columns elsewhere.
    """
    assert block_size in [128, 64], "Radial Attention only supports block size of 128 or 64"
    assert mask.shape[0] == mask.shape[1], "Input mask must be square."

    if block_size == 128:
        if arch == "sm90":
            # split each 128-row block into two 64-row blocks
            new_mask = torch.repeat_interleave(mask, 2, dim=0)
        else:
            # split each 128-column block into two 64-column blocks
            new_mask = torch.repeat_interleave(mask, 2, dim=1)
    elif block_size == 64:
        if arch == "sm90":
            # merge adjacent 64-column blocks into 128-column blocks
            num_row, num_col = mask.shape
            reshaped_mask = mask.view(num_row, num_col // 2, 2)
            new_mask = torch.max(reshaped_mask, dim=2).values
        else:
            # merge adjacent 64-row blocks into 128-row blocks
            num_row, num_col = mask.shape
            reshaped_mask = mask.view(num_row // 2, 2, num_col)
            new_mask = torch.max(reshaped_mask, dim=1).values

    return new_mask
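
# Illustrative sketch of the conversion (not executed on import): with
# block_size=128 on a non-sm90 GPU, each 128-wide column block is split into
# two 64-wide kernel blocks, so columns are duplicated:
#   mask = torch.tensor([[True, False],
#                        [False, True]])
#   sparge_mask_convert(mask, block_size=128, arch="sm89")
#   # -> [[True, True, False, False],
#   #     [False, False, True, True]]    (shape (2, 4))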

def get_indptr_from_mask(mask, query):
    # `query` is only used to pick the device for the output.
    # indptr (torch.Tensor) - the block index pointer of the block-sparse matrix on the row dimension,
    # shape `(MB + 1,)`, where `MB` is the number of blocks in the row dimension.
    # The first element is always 0 and the last is the total number of non-zero blocks;
    # element i+1 is the cumulative count of non-zero blocks through row i (a prefix sum).
    # the mask is already a block-sparse mask
    indptr = torch.zeros(mask.shape[0] + 1, device=query.device, dtype=torch.int32)
    indptr[0] = 0
    row_counts = mask.sum(dim=1).flatten()  # Ensure 1D output [num_blocks_row]
    indptr[1:] = torch.cumsum(row_counts, dim=0)
    return indptr

def get_indices_from_mask(mask, query):
    # indices (torch.Tensor) - the block indices of the block-sparse matrix on the column dimension,
    # shape `(nnz,)`, where `nnz` is the number of non-zero blocks.
    # The elements in `indices` should be less than `NB`, the number of blocks in the
    # column dimension; torch.nonzero returns entries in row-major order, matching `indptr`.
    nonzero_indices = torch.nonzero(mask)
    indices = nonzero_indices[:, 1].to(dtype=torch.int32, device=query.device)
    return indices
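
# Illustrative sketch of the CSR encoding produced by the two helpers above
# (not executed on import; `query` is only consulted for its device):
#   mask = torch.tensor([[True, False, True],
#                        [False, False, True],
#                        [True, True, False]])
#   get_indptr_from_mask(mask, query)   # -> [0, 2, 3, 5]
#   get_indices_from_mask(mask, query)  # -> [0, 2, 2, 0, 1]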

def shrinkMaskStrict(mask, block_size=128):
    """
    Shrink a token-level mask to a block-level mask: a block is kept only if
    more than 60% of its non-empty columns are dense (column density > 1/3).
    """
    seqlen = mask.shape[0]
    block_num = seqlen // block_size
    mask = mask[:block_num * block_size, :block_num * block_size].view(block_num, block_size, block_num, block_size)
    # per-column density within each (block_size x block_size) block
    col_densities = mask.sum(dim=1) / block_size
    non_zero_densities = col_densities > 0
    high_density_cols = col_densities > 1/3
    frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9)
    block_mask = frac_high_density_cols > 0.6
    return block_mask
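
# Illustrative sketch of the shrink rule (not executed on import): within a
# 128x128 block, a column with 50 of its 128 rows set has density
# 50/128 ≈ 0.39 > 1/3 and counts as high-density; if, say, 80 of the block's
# 100 non-empty columns are high-density, frac_high_density_cols = 0.8 > 0.6
# and the block is kept.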

def pad_qkv(input_tensor, block_size=128):
    """
    Pad the input tensor to be a multiple of the block size.
    input shape: (seqlen, num_heads, hidden_dim)
    """
    seqlen, num_heads, hidden_dim = input_tensor.shape
    # Calculate the necessary padding
    padding_length = (block_size - (seqlen % block_size)) % block_size
    # Create a padded tensor with zeros
    padded_tensor = torch.zeros((seqlen + padding_length, num_heads, hidden_dim), device=input_tensor.device, dtype=input_tensor.dtype)
    # Copy the original tensor into the padded tensor
    padded_tensor[:seqlen, :, :] = input_tensor
    
    return padded_tensor
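
# Illustrative sketch (not executed on import):
#   x = torch.randn(300, 8, 64)       # seqlen=300, 8 heads, head_dim=64
#   pad_qkv(x, block_size=128).shape  # -> (384, 8, 64); rows 300..383 are zeros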

def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query):
    """
    For distant frame pairs, keep only every `split_factor`-th diagonal so
    that attention density decays with temporal distance.
    """
    assert sparse_type in ["radial"]
    dist = abs(i - j)
    group = dist.bit_length()
    threshold = 128  # hardcoded for now; equal to the block size
    decay_length = 2 ** token_per_frame.bit_length() / 2 ** group
    if decay_length >= threshold:
        return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
    
    split_factor = int(threshold / decay_length)
    modular = dist % split_factor
    if modular == 0:
        return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
    else:
        return torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
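
# Illustrative sketch of the split rule (not executed on import), with
# token_per_frame = 1560 (bit_length 11, so 2**11 = 2048):
#   dist = 8  -> group = 4, decay_length = 2048 / 16 = 128 >= 128: kept dense
#   dist = 33 -> group = 6, decay_length = 32, split_factor = 4;
#                33 % 4 != 0, so this frame pair is dropped entirely,
#                while dist = 32 (32 % 4 == 0) survives.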

def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None):
    """
    Width of the attention window between frames i and j; it halves as the
    frame distance doubles, and is clamped below by the block size.
    """
    assert sparse_type in ["radial"]
    dist = abs(i - j)
    if model_type == "wan":
        if dist < 1:
            return token_per_frame
        if dist == 1:
            return token_per_frame // 2
    elif model_type == "hunyuan":
        if dist <= 1:
            return token_per_frame
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    group = dist.bit_length()
    decay_length = 2 ** token_per_frame.bit_length() / 2 ** group * decay_factor
    threshold = block_size
    if decay_length >= threshold:
        return decay_length
    else:
        return threshold
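
# Illustrative sketch of the window decay (not executed on import), for
# model_type="hunyuan", token_per_frame = 1560, decay_factor = 1:
#   dist <= 1        -> full frame width (1560)
#   dist in {2, 3}   -> 2**11 / 2**2 = 512
#   dist in [4, 7]   -> 256
#   larger distances -> clamped below by threshold = block_size = 128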

def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128, sparse_type="radial", decay_factor=0.5, model_type=None):
    """
    A more memory-friendly version: generate the attention mask for each
    frame pair one at a time, shrink it to block granularity, and store it
    into the final result.
    """
    final_log_mask = torch.zeros((s // block_size, s // block_size), device=query.device, dtype=torch.bool)
    token_per_frame = video_token_num // num_frame
    video_text_border = video_token_num // block_size

    col_indices = torch.arange(0, token_per_frame, device=query.device).view(1, -1)
    row_indices = torch.arange(0, token_per_frame, device=query.device).view(-1, 1)
    final_log_mask[video_text_border:] = True
    final_log_mask[:, video_text_border:] = True
    for i in range(num_frame):
        for j in range(num_frame):
            local_mask = torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
            if j == 0 and model_type == "wan": # this is attention sink
                local_mask = torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
            else:
                window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type)
                local_mask = torch.abs(col_indices - row_indices) <= window_width
                split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query)
                local_mask = torch.logical_and(local_mask, split_mask)

            remainder_row = (i * token_per_frame) % block_size
            remainder_col = (j * token_per_frame) % block_size
            # get the padded size
            all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size
            all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size
            padded_local_mask = torch.zeros((all_length_row, all_length_col), device=query.device, dtype=torch.bool)
            padded_local_mask[remainder_row:remainder_row + token_per_frame, remainder_col:remainder_col + token_per_frame] = local_mask
            # shrink the mask
            block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size)
            # set the block mask to the final log mask
            block_row_start = (i * token_per_frame) // block_size
            block_col_start = (j * token_per_frame) // block_size
            block_row_end = block_row_start + block_mask.shape[0]
            block_col_end = block_col_start + block_mask.shape[1]
            final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(
                final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask)
    print(f"mask sparsity: {1 - final_log_mask.sum() / final_log_mask.numel()}")
    return final_log_mask

class MaskMap:
    _log_mask = None

    def __init__(self, video_token_num=25440, num_frame=16):
        self.video_token_num = video_token_num
        self.num_frame = num_frame

    def queryLogMask(self, query, sparse_type, block_size=128, decay_factor=0.5, model_type=None):
        # build the block-level mask once and cache it for all subsequent calls
        if MaskMap._log_mask is None:
            MaskMap._log_mask = gen_log_mask_shrinked(query, query.shape[0], self.video_token_num, self.num_frame, sparse_type=sparse_type, decay_factor=decay_factor, model_type=model_type, block_size=block_size)
        return MaskMap._log_mask
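
# Illustrative usage sketch (not executed on import; shapes are assumptions):
#   mask_map = MaskMap(video_token_num=115200, num_frame=30)
#   # query: (padded_seqlen, num_heads, head_dim) on a CUDA device
#   video_mask = mask_map.queryLogMask(query, "radial", block_size=128,
#                                      decay_factor=1, model_type="hunyuan")
#   # video_mask: bool tensor of shape (padded_seqlen // 128, padded_seqlen // 128),
#   # cached in MaskMap._log_mask after the first call.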

def SpargeSageAttnBackend(query, key, value, mask_map=None, video_mask=None, pre_defined_mask=None, block_size=128):
    if video_mask.all():
        # dense case
        # number of valid (non-padding) key/value tokens
        kv_border = pre_defined_mask[0].sum() if pre_defined_mask is not None else key.shape[0]
        output_video = sageattn(
            query[:mask_map.video_token_num, :, :].unsqueeze(0),
            key[:kv_border, :, :].unsqueeze(0),
            value[:kv_border, :, :].unsqueeze(0),
            tensor_layout="NHD",
        )[0]
        
        if pre_defined_mask is not None:
            output_text = flashinfer.single_prefill_with_kv_cache(
                q=query[mask_map.video_token_num:, :, :],
                k=key[:pre_defined_mask[0].sum(), :, :],
                v=value[:pre_defined_mask[0].sum(), :, :],
                causal=False,
                return_lse=False,
            )
            return torch.cat([output_video, output_text], dim=0)
        else:
            return output_video
    
    # sparse-sageattention only supports the (b, h, s, d) layout, so rearrange first
    query_hnd = rearrange(query.unsqueeze(0), "b s h d -> b h s d")
    key_hnd = rearrange(key.unsqueeze(0), "b s h d -> b h s d")
    value_hnd = rearrange(value.unsqueeze(0), "b s h d -> b h s d")
    arch = get_cuda_arch_versions()[query.device.index]
    converted_mask = repeat(sparge_mask_convert(mask=video_mask, block_size=block_size, arch=arch), "s t -> b h s t", b=query_hnd.shape[0], h=query_hnd.shape[1])
    
    converted_mask = converted_mask.to(torch.int8)
    if pre_defined_mask is None:
        # wan case
        output = block_sparse_sage2_attn_cuda(
            query_hnd[:, :, :mask_map.video_token_num, :],
            key_hnd[:, :, :mask_map.video_token_num, :],
            value_hnd[:, :, :mask_map.video_token_num, :],
            mask_id=converted_mask,
            tensor_layout="HND",
        )

        # rearrange back to (s, h, d), we know that b = 1
        output = rearrange(output, "b h s d -> s (b h) d", b=1)
        return output
    
    query_video = query_hnd[:, :, :mask_map.video_token_num, :]
    key_video = key_hnd
    value_video = value_hnd
    # round the number of valid key tokens up to the converted mask's column
    # granularity (assumes 64-wide key blocks after conversion), then disable
    # all key blocks beyond it so text keys are masked out
    kv_border = (pre_defined_mask[0].sum() + 63) // 64
    converted_mask[:, :, :, kv_border:] = False
    output_video = block_sparse_sage2_attn_cuda(
        query_video,
        key_video,
        value_video,
        mask_id=converted_mask[:, :, :mask_map.video_token_num // block_size, :].contiguous(),
        tensor_layout="HND",
    )
    
    # rearrange back to (s, h, d), we know that b = 1
    output_video = rearrange(output_video, "b h s d -> s (b h) d", b=1)
    
    output_text = flashinfer.single_prefill_with_kv_cache(
        q=query[mask_map.video_token_num:, :, :],
        k=key[:pre_defined_mask[0].sum(), :, :],
        v=value[:pre_defined_mask[0].sum(), :, :],
        causal=False,
        return_lse=False,
    )
    
    return torch.cat([output_video, output_text], dim=0)
    

def FlashInferBackend(query, key, value, mask_map=None, pre_defined_mask=None, bsr_wrapper=None):
    if pre_defined_mask is not None:
        video_video_o, video_video_o_lse = bsr_wrapper.run(
            query[:mask_map.video_token_num, :, :], 
            key[:mask_map.video_token_num, :, :],
            value[:mask_map.video_token_num, :, :],
            return_lse=True
        ) 
        # video queries attend to the text keys/values (non-causal) under the pre-defined mask
        video_text_o, video_text_o_lse = flashinfer.single_prefill_with_kv_cache(
            q=query[:mask_map.video_token_num, :, :],
            k=key[mask_map.video_token_num:, :, :],
            v=value[mask_map.video_token_num:, :, :],
            causal=False,
            return_lse=True,
            custom_mask=pre_defined_mask[:mask_map.video_token_num, mask_map.video_token_num:]
        )
        
        # merge the two results
        o_video, _ = flashinfer.merge_state(v_a=video_video_o, s_a=video_video_o_lse, v_b=video_text_o, s_b=video_text_o_lse)
        
        o_text = flashinfer.single_prefill_with_kv_cache(
            q=query[mask_map.video_token_num:, :, :],
            k=key,
            v=value,
            causal=False,
            return_lse=False,
            custom_mask=pre_defined_mask[mask_map.video_token_num:, :]
        )
        
        return torch.cat([o_video, o_text], dim=0)
    else:
        o = bsr_wrapper.run(
            query[:mask_map.video_token_num, :, :],
            key[:mask_map.video_token_num, :, :],
            value[:mask_map.video_token_num, :, :]
        )
        return o

def RadialAttention(query, key, value, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_type=None, pre_defined_mask=None, use_sage_attention=False):
    orig_seqlen, num_head, hidden_dim = query.shape

    if sparsity_type == "dense":
        video_mask = torch.ones((mask_map.video_token_num // block_size, mask_map.video_token_num // block_size), device=query.device, dtype=torch.bool)
    else:
        video_mask = mask_map.queryLogMask(query, sparsity_type, block_size=block_size, decay_factor=decay_factor, model_type=model_type) if mask_map else None
    
    backend = "sparse_sageattn" if use_sage_attention else "flashinfer"
    
    if backend == "flashinfer":
        video_mask = video_mask[:mask_map.video_token_num // block_size, :mask_map.video_token_num // block_size]
        # perform block-sparse attention on the video tokens
        workspace_buffer = torch.empty(128 * 1024 * 1024, device=query.device, dtype=torch.uint8)
        bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(
            workspace_buffer,
            backend="fa2",
        )
        
        indptr = get_indptr_from_mask(video_mask, query)
        indices = get_indices_from_mask(video_mask, query)
        
        bsr_wrapper.plan(
            indptr=indptr,
            indices=indices,
            M=mask_map.video_token_num,
            N=mask_map.video_token_num,
            R=block_size,
            C=block_size,
            num_qo_heads=num_head,
            num_kv_heads=num_head,
            head_dim=hidden_dim,
            q_data_type=query.dtype,
            kv_data_type=key.dtype,
            o_data_type=query.dtype,
        )
        
        return FlashInferBackend(query, key, value, mask_map, pre_defined_mask, bsr_wrapper)
    elif backend == "sparse_sageattn":
        return SpargeSageAttnBackend(query, key, value, mask_map, video_mask, pre_defined_mask, block_size=block_size)
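
# Illustrative usage sketch (not executed on import; shapes are assumptions):
#   out = RadialAttention(query, key, value, mask_map=mask_map,
#                         sparsity_type="radial", block_size=128,
#                         decay_factor=1, model_type="hunyuan",
#                         pre_defined_mask=None, use_sage_attention=True)
#   # query/key/value: (seqlen, num_heads, head_dim); `out` keeps that layout.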
        
if __name__ == "__main__":
    query = torch.randn(1, 2, 4, 64).cuda()
    # mask = torch.tensor([
    #     [True, False, True, False],
    #     [False, True, False, True],
    #     [True, False, False, True],
    #     [False, True, True, False]
    # ], dtype=torch.bool)
    # indices = get_indices_from_mask(mask, query)
    # indptr = get_indptr_from_mask(mask, query)
    # print("Indices: ", indices)
    # print("Indptr: ", indptr)
    video_token_num = 3840 * 30
    num_frame = 30
    token_per_frame = video_token_num // num_frame  # integer division: 3840 tokens per frame
    padded_video_token_num = ((video_token_num + 1) // 128 + 1) * 128
    print("padded: ", padded_video_token_num)
    temporal_mask = gen_log_mask_shrinked(query, padded_video_token_num, video_token_num, num_frame, sparse_type="radial", decay_factor=1, model_type="hunyuan")
    plt.figure(figsize=(10, 8), dpi=500)

    plt.imshow(temporal_mask.cpu().numpy(), cmap='hot')
    plt.colorbar()
    plt.title("Temporal Mask")

    plt.savefig("temporal_mask.png",
                dpi=300,
                bbox_inches='tight',
                pad_inches=0.1)

    plt.close()
    # save the mask tensor
    torch.save(temporal_mask, "temporal_mask.pt")