import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from timm.models.layers import to_2tuple


class PatchEmbed_new(nn.Module):
    """ Flexible Image to Patch Embedding

    kernel_size and stride of the projection are decoupled, so a stride smaller
    than patch_size produces overlapping patches.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=16):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)

        self.img_size = img_size
        self.patch_size = patch_size

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)

    def forward(self, x):
        x = self.proj(x)                  # (B, embed_dim, H', W')
        x = x.flatten(2).transpose(1, 2)  # (B, H'*W', embed_dim)
        return x
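

# Illustrative usage sketch (not part of the original file): shows the token
# shape PatchEmbed_new produces for an arbitrary 1-channel, spectrogram-like
# input; all sizes below are example values.
def _example_patch_embed():
    patch_embed = PatchEmbed_new(
        img_size=(128, 1024), patch_size=16, in_chans=1, embed_dim=768, stride=16
    )
    x = torch.randn(2, 1, 128, 1024)  # (B, C, H, W)
    tokens = patch_embed(x)           # (128 / 16) * (1024 / 16) = 512 patches per sample
    print(tokens.shape)               # torch.Size([2, 512, 768])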
					
						
def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
    """
    grid_size: (height, width) tuple of the grid
    return:
    pos_embed: [grid_size[0]*grid_size[1], embed_dim] or [1+grid_size[0]*grid_size[1], embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of the dimensions to encode grid_h, the other half for grid_w
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: array of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000 ** omega  # (D/2,) geometric frequency ladder

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
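

# Illustrative sketch (not part of the original file): builds the fixed 2-D
# sin-cos table for a hypothetical 8x64 patch grid; half of embed_dim encodes
# the row index, the other half the column index.
def _example_sincos_pos_embed():
    embed_dim, grid_hw = 768, (8, 64)
    pos_embed = get_2d_sincos_pos_embed_flexible(embed_dim, grid_hw, cls_token=False)
    assert pos_embed.shape == (grid_hw[0] * grid_hw[1], embed_dim)  # (512, 768)
    return pos_embed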
					
						
class FixedPositionalEncoder(nn.Module):
    def __init__(self, pos_embed):
        super().__init__()
        self.positions = pos_embed

    def forward(self, x, padding_mask):
        return self.positions
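

# Illustrative sketch (not part of the original file): wraps a precomputed
# sin-cos table in FixedPositionalEncoder, which returns the same fixed
# positions regardless of its inputs. Sizes are example values.
def _example_fixed_positional_encoder():
    pos = get_2d_sincos_pos_embed_flexible(768, (8, 64), cls_token=False)
    pos = torch.from_numpy(pos).float().unsqueeze(0)  # (1, 512, 768)
    encoder = FixedPositionalEncoder(pos)
    x = torch.randn(2, 512, 768)
    print(encoder(x, padding_mask=None).shape)  # torch.Size([1, 512, 768])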
					
						
class AltBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        mlp_drop=0.0,
        post_mlp_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        ffn_targets=False,
        cosine_attention=False,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first
        self.ffn_targets = ffn_targets

        from timm.models.vision_transformer import DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.attn = AltAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            cosine_attention=cosine_attention,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop,
        )
        self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
					
						
    def forward(self, x, padding_mask=None, alibi_bias=None):
        if self.layer_norm_first:
            # pre-norm path: keep the post-attention state as the residual for the MLP sub-block
            r = x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
            x = self.mlp(self.norm2(x))
            t = x
            x = r + self.drop_path(self.post_mlp_dropout(x))
            if not self.ffn_targets:
                t = x
        else:
            x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
            r = x = self.norm1(x)
            x = self.mlp(x)
            t = x
            x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
            if not self.ffn_targets:
                t = x

        return x, t
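

# Illustrative sketch (not part of the original file): one forward pass through
# an AltBlock. The second return value `t` is the layer's target features: the
# FFN output when ffn_targets=True, otherwise the block output itself.
def _example_alt_block():
    block = AltBlock(dim=768, num_heads=12, mlp_ratio=4.0, qkv_bias=True, ffn_targets=True)
    x = torch.randn(2, 512, 768)  # (B, N, dim)
    out, target = block(x)
    print(out.shape, target.shape)  # both torch.Size([2, 512, 768])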
					
						
class AltAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        cosine_attention=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.cosine_attention = cosine_attention

        if cosine_attention:
            self.logit_scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
            )

    def forward(self, x, padding_mask=None, alibi_bias=None):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, num_heads, N, head_dim)

        dtype = q.dtype

        if self.cosine_attention:
            # cosine attention: similarity of L2-normalised q/k, scaled by a
            # learned, clamped per-head temperature
            attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
            logit_scale = torch.clamp(
                self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
            ).exp()
            attn = attn * logit_scale
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

        if alibi_bias is not None:
            attn = attn.type_as(alibi_bias)
            attn[:, : alibi_bias.size(1)] += alibi_bias

        if padding_mask is not None and padding_mask.any():
            # mask out padded key positions before the softmax
            attn = attn.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )

        attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2)
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
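

# Illustrative sketch (not part of the original file): scaled dot-product
# attention with a key-padding mask. Positions where padding_mask is True get
# -inf logits and therefore zero attention weight after the softmax.
def _example_alt_attention():
    attn = AltAttention(dim=768, num_heads=12, qkv_bias=True)
    x = torch.randn(2, 10, 768)  # (B, N, dim)
    padding_mask = torch.zeros(2, 10, dtype=torch.bool)
    padding_mask[:, 8:] = True   # last two tokens of every sequence are padding
    out = attn(x, padding_mask=padding_mask)
    print(out.shape)             # torch.Size([2, 10, 768])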