"""Attention layer."""
from __future__ import annotations
from torch import Tensor, nn
from vis4d.common.logging import rank_zero_warn
from vis4d.common.typing import ArgsType


class Attention(nn.Module):
    """ViT Attention Layer.

    Modified from timm (https://github.com/huggingface/pytorch-image-models).
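
    Example:
        Illustrative usage sketch; the shapes below are arbitrary and only
        assume the usual ``(B, N, dim)`` token layout:

        >>> import torch
        >>> attn = Attention(dim=64, num_heads=8)
        >>> tokens = torch.rand(2, 16, 64)  # (B, N, dim), assumed layout
        >>> tuple(attn(tokens).shape)
        (2, 16, 64)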
"""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """Init attention layer.

        Args:
            dim (int): Input tensor's dimension.
            num_heads (int, optional): Number of attention heads. Defaults
                to 8.
            qkv_bias (bool, optional): Whether to add a bias term to the qkv
                projection. Defaults to False.
            attn_drop (float, optional): Dropout rate for attention. Defaults
                to 0.0.
            proj_drop (float, optional): Dropout rate for projection. Defaults
                to 0.0.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Scale dot products by 1 / sqrt(head_dim), as in standard scaled
        # dot-product attention.
        self.scale = head_dim**-0.5
        # Single linear projection producing query, key and value in one pass.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def __call__(self, data: Tensor) -> Tensor:
        """Applies the layer.

        Args:
            data (Tensor): Input tensor of shape (B, N, dim).

        Returns:
            Tensor: Output tensor of the same shape as input.
        """
        return self._call_impl(data)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass."""
        batch_size, num_samples, dim = x.shape
        # Project to q, k, v and rearrange to
        # (3, batch, num_heads, num_samples, head_dim).
        qkv = (
            self.qkv(x)
            .reshape(
                batch_size,
                num_samples,
                3,
                self.num_heads,
                dim // self.num_heads,
            )
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(
            0
        )  # make torchscript happy (cannot use tensor as tuple)

        # Scaled dot-product attention over the token dimension.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads and apply the output projection.
        x = (attn @ v).transpose(1, 2).reshape(batch_size, num_samples, dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MultiheadAttention(nn.Module):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with an identity (residual)
    connection, and positional encodings can be passed as additional inputs.
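
    Example:
        Minimal self-attention sketch; shapes are illustrative, and the
        default ``batch_first=False`` layout [num_queries, bs, embed_dims]
        is assumed:

        >>> import torch
        >>> attn = MultiheadAttention(embed_dims=64, num_heads=8)
        >>> query = torch.rand(10, 2, 64)  # (num_queries, bs, embed_dims)
        >>> tuple(attn(query).shape)
        (10, 2, 64)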
"""

    def __init__(
        self,
        embed_dims: int,
        num_heads: int,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        dropout_layer: nn.Module | None = None,
        batch_first: bool = False,
        need_weights: bool = False,
        **kwargs: ArgsType,
    ) -> None:
        """Init MultiheadAttention.

        Args:
            embed_dims (int): The embedding dimension.
            num_heads (int): Parallel attention heads.
            attn_drop (float): Dropout rate applied to the attention weights
                (attn_output_weights). Defaults to 0.0.
            proj_drop (float): Dropout rate applied after
                ``nn.MultiheadAttention``. Defaults to 0.0.
            dropout_layer (nn.Module | None, optional): The dropout layer used
                when adding the shortcut. Defaults to None, in which case
                ``nn.Identity`` is used.
            batch_first (bool): If True, key, query and value tensors have
                shape (batch, n, embed_dim); otherwise (n, batch, embed_dim).
                Defaults to False.
            need_weights (bool): Whether to return the attention weights. If
                True, ``nn.MultiheadAttention`` returns a tuple of
                (attn_output, attn_output_weights) and does not use the fused
                FlashAttention path; only attn_output is propagated further.
                If False, only attn_output is computed. Defaults to False.
        """
        super().__init__()
        self.batch_first = batch_first
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.need_weights = need_weights

        self.attn = nn.MultiheadAttention(
            embed_dims, num_heads, dropout=attn_drop, **kwargs
        )
        self.proj_drop = nn.Dropout(proj_drop)
        # Keep the residual path a no-op if no dropout layer is configured.
        self.dropout_layer = dropout_layer or nn.Identity()

    def forward(
        self,
        query: Tensor,
        key: Tensor | None = None,
        value: Tensor | None = None,
        identity: Tensor | None = None,
        query_pos: Tensor | None = None,
        key_pos: Tensor | None = None,
        attn_mask: Tensor | None = None,
        key_padding_mask: Tensor | None = None,
    ) -> Tensor:
        """Forward function for `MultiheadAttention`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims]. If None, the ``query`` will be
                used. Defaults to None.
            value (Tensor): The value tensor with the same shape as `key`,
                as in `nn.MultiheadAttention.forward`. If None, the `key`
                will be used. Defaults to None.
            identity (Tensor): This tensor, with the same shape as query,
                will be used for the identity link. If None, `query` will
                be used. Defaults to None.
            query_pos (Tensor): The positional encoding for query, with the
                same shape as `query`. If not None, it will be added to
                `query` before attention. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not None, it will be added to `key`
                before attention. If None and `query_pos` has the same shape
                as `key`, then `query_pos` will be used for `key_pos`.
                Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys], as in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tensor: Forwarded results with shape [num_queries, bs, embed_dims]
                if self.batch_first is False, else [bs, num_queries,
                embed_dims].
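
        Example:
            Cross-attention with positional encodings and batch-first inputs
            (shapes are illustrative only):

            >>> import torch
            >>> attn = MultiheadAttention(
            ...     embed_dims=32, num_heads=4, batch_first=True
            ... )
            >>> query = torch.rand(2, 5, 32)  # [bs, num_queries, embed_dims]
            >>> key = torch.rand(2, 7, 32)  # [bs, num_keys, embed_dims]
            >>> out = attn(
            ...     query,
            ...     key,
            ...     query_pos=torch.rand(2, 5, 32),
            ...     key_pos=torch.rand(2, 7, 32),
            ... )
            >>> tuple(out.shape)
            (2, 5, 32)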
"""
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None and query_pos is not None:
            # use query_pos if key_pos is not available
            if query_pos.shape == key.shape:
                key_pos = query_pos
            else:
                rank_zero_warn(
                    f"Positional encoding of key in {self.__class__.__name__}"
                    " is missing, and the positional encoding of query has a"
                    " different shape and cannot be used for key. If this is"
                    " not desired, please provide key_pos."
                )
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # Because the dataflow ('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch, embed_dims),
        # we adjust the shape from batch_first (batch, num_query, embed_dims)
        # to num_query first (num_query, batch, embed_dims) here, and recover
        # ``attn_output`` back to batch_first after the attention call.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)
        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=self.need_weights,
        )
        # ``nn.MultiheadAttention`` may return a tuple of
        # (attn_output, attn_output_weights); keep only the attention output.
        if isinstance(out, tuple):
            out = out[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))