Spaces:
Paused
Paused
| import os | |
| import torch | |
| import datetime | |
| import torch.distributed as dist | |
| from typing import Any, Tuple | |
| from torch import Tensor | |
| from flash_attn.flash_attn_interface import flash_attn_varlen_func | |
class COMM_INFO:
    """Mutable record of where this process sits in sequence parallelism.

    Every field starts at the value for a single, non-parallel process and is
    overwritten by ``initialize_sequence_parallel_state`` /
    ``initialize_sequence_parallel_group``.
    """

    def __init__(self):
        # Process group of this rank's sequence-parallel group; None until
        # a group has been created.
        self.group = None
        # Rank of this process in the whole world.
        self.global_rank = 0
        # Index of the sequence-parallel group this process belongs to.
        self.group_id = 0
        # Position of this process inside its sequence-parallel group.
        self.rank_within_group = 0
        # Number of ranks per sequence-parallel group.
        self.sp_size = 1


# Process-wide placement record read by the communication helpers.
nccl_info = COMM_INFO()
# Set to True by initialize_sequence_parallel_state() when sp size > 1.
_SEQUENCE_PARALLEL_STATE = False
def get_cu_seqlens(text_mask, img_len, device="cuda"):
    """Calculate cu_seqlens for ``flash_attn_varlen_func`` from a text mask.

    Each sample occupies ``max_len = img_len + text_mask.shape[1]`` packed
    tokens and is split into two variable-length segments:

    * the first ``img_len + text_len[i]`` tokens, and
    * the remaining ``max_len - (img_len + text_len[i])`` tokens (masked-out
      text), which therefore never attend to the first segment.

    ``text_len[i]`` is the row sum of ``text_mask``.  This presumes the valid
    text tokens sit at the front of each mask row -- confirm against callers.

    Args:
        text_mask (torch.Tensor): ``(batch, text_len)`` mask of the text tokens.
        img_len (int): number of image tokens per sample.
        device (str | torch.device): device of the returned tensor.  Defaults
            to ``"cuda"``, the value that used to be hard-coded.

    Returns:
        torch.Tensor: int32 tensor of shape ``(2 * batch + 1,)`` laid out as
        ``[0, end_valid_0, end_0, end_valid_1, end_1, ...]``.
    """
    batch_size = text_mask.shape[0]
    max_len = text_mask.shape[1] + img_len
    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=device)
    # Start offset of every sample in the packed (batch * max_len) layout.
    offsets = torch.arange(batch_size, device=device) * max_len
    text_len = text_mask.sum(dim=1).to(device)
    # Odd slots: end of sample i's first segment (image + valid text).
    cu_seqlens[1::2] = offsets + text_len + img_len
    # Even slots (after the leading 0): end of sample i as a whole.
    cu_seqlens[2::2] = offsets + max_len
    return cu_seqlens
def initialize_sequence_parallel_state(sequence_parallel_size):
    """Turn sequence parallelism on, or record a single-rank placement.

    With ``sequence_parallel_size > 1`` the module flag is raised and the
    process groups are built by ``initialize_sequence_parallel_group``.
    Otherwise ``nccl_info`` describes a group of size one whose id is the
    ``RANK`` environment variable (default 0).
    """
    global _SEQUENCE_PARALLEL_STATE
    if sequence_parallel_size > 1:
        _SEQUENCE_PARALLEL_STATE = True
        initialize_sequence_parallel_group(sequence_parallel_size)
        return

    env_rank = int(os.getenv("RANK", "0"))
    nccl_info.sp_size = 1
    nccl_info.global_rank = env_rank
    nccl_info.rank_within_group = 0
    nccl_info.group_id = env_rank
def get_sequence_parallel_state():
    """Return True once sequence parallelism was initialized with more than one rank."""
    enabled = _SEQUENCE_PARALLEL_STATE
    return enabled
def initialize_sequence_parallel_group(sequence_parallel_size):
    """Initialize the sequence parallel group.

    The world (size from ``WORLD_SIZE``) is cut into consecutive blocks of
    ``sequence_parallel_size`` ranks -- ``[0, sp)``, ``[sp, 2*sp)``, ... --
    and ``nccl_info`` is filled in for the block containing this process
    (rank from ``RANK``).
    """
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    assert world_size % sequence_parallel_size == 0, "world_size must be divisible by sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format(world_size, sequence_parallel_size)

    nccl_info.sp_size = sequence_parallel_size
    nccl_info.global_rank = rank

    num_sequence_parallel_groups: int = world_size // sequence_parallel_size
    for group_idx in range(num_sequence_parallel_groups):
        first = group_idx * sequence_parallel_size
        members = range(first, first + sequence_parallel_size)
        # Per the torch.distributed docs every process must take part in
        # every new_group() call, so the group is created before the
        # membership test.
        group = dist.new_group(members)
        if rank in members:
            nccl_info.group = group
            nccl_info.rank_within_group = rank - first
            nccl_info.group_id = group_idx
def initialize_distributed(seed):
    """Set up NCCL, seed the RNGs and enable sequence parallelism.

    Reads ``RANK`` and ``WORLD_SIZE`` (defaults 0 / 1) for the process group.
    The CUDA device comes from ``LOCAL_RANK`` -- the rank on this node --
    falling back to ``RANK`` when it is unset, which reproduces the earlier
    single-node behaviour.  The whole world becomes one sequence-parallel
    group.

    Args:
        seed (int): seed for ``torch.manual_seed`` and
            ``torch.cuda.manual_seed_all``.
    """
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    # On multi-node runs the global rank exceeds the per-node GPU count, so
    # the device index must be node-local.
    local_rank = int(os.getenv("LOCAL_RANK", str(rank)))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        # 2**31 - 1 seconds (~68 years): effectively no collective timeout.
        timeout=datetime.timedelta(seconds=2**31 - 1),
        world_size=world_size,
        rank=rank,
    )
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    initialize_sequence_parallel_state(world_size)
def _all_to_all_4D(input: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    """All-to-all for QKV / attention outputs stored as 4D tensors.

    With ``P`` the size of ``group``, two directions are supported:

    * ``scatter_idx=2, gather_idx=1``: ``(bs, seqlen/P, hc, hs)`` ->
      ``(bs, seqlen, hc/P, hs)`` (gather the sequence, scatter the heads).
    * ``scatter_idx=1, gather_idx=2``: ``(bs, seqlen, hc/P, hs)`` ->
      ``(bs, seqlen/P, hc, hs)`` (gather the heads, scatter the sequence).

    Args:
        input (torch.Tensor): 4D tensor ``(bs, seq, heads, head_dim)``.
        scatter_idx (int): dimension to scatter. Default 2.
        gather_idx (int): dimension to gather. Default 1.
        group: torch process group (``None`` = default group).

    Returns:
        torch.Tensor: the resharded tensor (shapes above).

    Raises:
        RuntimeError: for any other ``(scatter_idx, gather_idx)`` pair.
    """
    assert (input.dim() == 4), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"
    seq_world_size = dist.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        bs, shard_seqlen, hc, hs = input.shape
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size
        # (bs, seqlen/P, hc, hs) -reshape-> (bs, seqlen/P, P, hc/P, hs)
        # -transpose(0, 2)-> (P, seqlen/P, bs, hc/P, hs); all_to_all_single
        # sends chunk p of dim 0 (head group p) to rank p.
        input_t = input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs).transpose(0, 2).contiguous()
        if seq_world_size > 1:
            output = torch.empty_like(input_t)
            dist.all_to_all_single(output, input_t, group=group)
            # Host waits for the collective before the result is reshaped.
            torch.cuda.synchronize()
        else:
            output = input_t
        # Chunk q of dim 0 now holds rank q's sequence shard, so (P, seqlen/P)
        # merge into the full sequence: (seqlen, bs, hc/P, hs).
        output = output.reshape(seqlen, bs, shard_hc, hs)
        # (seqlen, bs, hc/P, hs) -> (bs, seqlen, hc/P, hs)
        return output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)

    if scatter_idx == 1 and gather_idx == 2:
        bs, seqlen, shard_hc, hs = input.shape
        hc = shard_hc * seq_world_size
        shard_seqlen = seqlen // seq_world_size
        # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seqlen/P, hc/P, hs)
        # -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs)
        # -transpose(0, 1)-> (P, hc/P, seqlen/P, bs, hs)
        input_t = (
            input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs)
            .transpose(0, 3)
            .transpose(0, 1)
            .contiguous()
            .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs)
        )
        if seq_world_size > 1:
            output = torch.empty_like(input_t)
            dist.all_to_all_single(output, input_t, group=group)
            # Host waits for the collective before the result is reshaped.
            torch.cuda.synchronize()
        else:
            output = input_t
        # (P, hc/P, seqlen/P, bs, hs) -> (hc, seqlen/P, bs, hs)
        output = output.reshape(hc, shard_seqlen, bs, hs)
        # (hc, seqlen/P, bs, hs) -transpose(0, 2)-> (bs, seqlen/P, hc, hs)
        return output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs)

    raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")
class SeqAllToAll4D(torch.autograd.Function):
    """Autograd function wrapping ``_all_to_all_4D``.

    The backward pass performs the inverse exchange (scatter and gather
    dimensions swapped) through ``SeqAllToAll4D.apply``, so the gradient
    computation is itself recorded by autograd.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # Only ``input`` receives a gradient; group and the two indices do not.
        return (
            None,
            SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx),
            None,
            None,
        )
def all_to_all_4D(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Differentiable 4D all-to-all over the sequence-parallel group.

    Forwards to ``SeqAllToAll4D`` using ``nccl_info.group`` as the group.
    """
    sp_group = nccl_info.group
    return SeqAllToAll4D.apply(sp_group, input_, scatter_dim, gather_dim)
def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
):
    """Split along ``scatter_dim``, exchange chunks, concatenate along ``gather_dim``.

    Chunk ``j`` of this rank's split is sent to rank ``j``.  Every rank is
    assumed to pass a tensor of the same shape.

    Args:
        input_: local tensor.
        world_size: number of ranks in ``group`` (number of chunks).
        group: process group to communicate over.
        scatter_dim: dimension split into ``world_size`` chunks.
        gather_dim: dimension along which the received chunks are joined.

    Returns:
        torch.Tensor: contiguous concatenation of the received chunks.
    """
    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
    # Each peer sends us *its* chunk number ``rank``, so every receive buffer
    # must match our own chunk ``rank``.  tensor_split makes leading chunks
    # larger when the size is not divisible, so sizing from chunk 0 is wrong
    # on every rank but rank 0.
    rank = dist.get_rank(group)
    output_list = [torch.empty_like(input_list[rank]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()
class _AllToAll(torch.autograd.Function):
    """All-to-all communication with autograd support.

    Forward splits along ``scatter_dim`` and concatenates along
    ``gather_dim``; backward runs the same exchange with the two dimensions
    swapped.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = dist.get_world_size(process_group)
        return _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Inverse exchange: scatter along the forward gather dim and vice versa.
        grad_input = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        # process_group, scatter_dim and gather_dim get no gradient.
        return grad_input, None, None, None
def all_to_all(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Differentiable all-to-all over the sequence-parallel group (``nccl_info.group``)."""
    sp_group = nccl_info.group
    return _AllToAll.apply(input_, sp_group, scatter_dim, gather_dim)
class _AllGather(torch.autograd.Function):
    """All-gather over the sequence-parallel group with autograd support.

    The group and its size come from ``nccl_info.group`` / ``nccl_info.sp_size``.

    Args:
        input_: input tensor
        dim: dimension along which to concatenate
    """

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        world_size = nccl_info.sp_size
        group = nccl_info.group
        # Local extent along ``dim``; backward slices this much back out.
        ctx.input_size = input_.size(dim)
        # Make the send tensor contiguous first so the receive buffers
        # allocated from it are contiguous as well.
        input_ = input_.contiguous()
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(tensor_list, input_, group=group)
        return torch.cat(tensor_list, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = nccl_info.sp_size
        rank = nccl_info.rank_within_group
        sizes = [ctx.input_size] * world_size
        # Keep the gradient slice that corresponds to this rank's input.
        # NOTE(review): exact only if grad_output is identical on all ranks;
        # otherwise a reduce-scatter (sum over ranks) would be required.
        grad_input = torch.split(grad_output, sizes, dim=ctx.dim)[rank]
        return grad_input, None
def all_gather(input_: torch.Tensor, dim: int = 1):
    """All-gather ``input_`` over the sequence-parallel group and concatenate along ``dim``.

    Autograd-aware through ``_AllGather``.

    Args:
        input_ (torch.Tensor): local tensor of any rank; every process in the
            group is expected to pass the same shape.
        dim (int, optional): dimension along which to concatenate. Defaults to 1.

    Returns:
        torch.Tensor: ``input_``'s shape with ``dim`` multiplied by
        ``nccl_info.sp_size``.
    """
    gathered = _AllGather.apply(input_, dim)
    return gathered
def parallel_attention(q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,):
    """Joint image + text attention via ``flash_attn_varlen_func``.

    When sequence parallelism is enabled, the image q/k/v are exchanged so
    that each rank sees the full image sequence for a subset of heads.
    Afterwards the result is exchanged back and the text part is all-gathered
    over heads.

    Args:
        q, k, v: pairs ``(image_tensor, text_tensor)``.  Judging from the
            original example shapes these look like ``(bs, img_seq, heads,
            head_dim)`` and ``(bs, text_seq, heads, head_dim)``, with the
            image sequence sharded across ranks under sequence parallelism --
            TODO confirm against callers.
        img_q_len, img_kv_len: unused; kept for interface compatibility.
        cu_seqlens_q, cu_seqlens_kv: cumulative segment boundaries for
            ``flash_attn_varlen_func`` (e.g. from ``get_cu_seqlens``).
        max_seqlen_q, max_seqlen_kv: longest-segment bounds forwarded to
            ``flash_attn_varlen_func``.

    Returns:
        tuple: ``(attn, None)``.  ``attn`` is ``(bs, tokens, heads * head_dim)``,
        with the image tokens first and the text tokens after them.
    """
    query, encoder_query = q
    key, encoder_key = k
    value, encoder_value = v

    sequence_parallel = get_sequence_parallel_state()
    if sequence_parallel:
        # Image tokens: gather the sequence and keep this rank's heads,
        # (bs, seq/P, hc, hs) -> (bs, seq, hc/P, hs).
        query = all_to_all_4D(query, scatter_dim=2, gather_dim=1)
        key = all_to_all_4D(key, scatter_dim=2, gather_dim=1)
        value = all_to_all_4D(value, scatter_dim=2, gather_dim=1)

        # Text tokens are only narrowed to the same head slice (no exchange),
        # which presumes every rank already holds the full text sequence.
        def shrink_head(encoder_state, dim):
            local_heads = encoder_state.shape[dim] // nccl_info.sp_size
            return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads)

        encoder_query = shrink_head(encoder_query, dim=2)
        encoder_key = shrink_head(encoder_key, dim=2)
        encoder_value = shrink_head(encoder_value, dim=2)

    sequence_length = query.size(1)
    encoder_sequence_length = encoder_query.size(1)
    # Per-sample token order: [image tokens, text tokens].
    query = torch.cat([query, encoder_query], dim=1)
    key = torch.cat([key, encoder_key], dim=1)
    value = torch.cat([value, encoder_value], dim=1)

    bsz = query.shape[0]
    head = query.shape[-2]
    head_dim = query.shape[-1]
    # flash_attn_varlen_func takes packed (total_tokens, heads, head_dim) inputs.
    query, key, value = [x.view(x.shape[0] * x.shape[1], *x.shape[2:]) for x in (query, key, value)]
    hidden_states = flash_attn_varlen_func(
        query,
        key,
        value,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
    )
    # Unpack to (bsz, image + text tokens, heads, head_dim) using the real
    # per-sample length; max_seqlen_q only needs to bound the longest segment.
    hidden_states = hidden_states.view(bsz, sequence_length + encoder_sequence_length, head, head_dim)
    hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
        (sequence_length, encoder_sequence_length), dim=1
    )

    if sequence_parallel:
        # Undo the head sharding: (bs, seq, hc/P, hs) -> (bs, seq/P, hc, hs).
        hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2)
        # Text part: collect all heads back from every rank.
        encoder_hidden_states = all_gather(encoder_hidden_states, dim=2).contiguous()

    hidden_states = hidden_states.to(query.dtype)
    encoder_hidden_states = encoder_hidden_states.to(query.dtype)
    attn = torch.cat([hidden_states, encoder_hidden_states], dim=1)
    b, s, _, _ = attn.shape
    attn = attn.reshape(b, s, -1)
    return attn, None