# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Image encoder."""

import torch
from torch import nn


def space_to_depth(input, block_size):
    """Rearrange blocks of spatial data into depth."""
    if input.dim() == 3:
        # Flattened (N, H*W, C) token sequence: recover a square spatial layout.
        hXw, c = input.size()[1:]
        h = w = int(hXw**0.5)
    else:
        h, w, c = input.size()[1:]
    h1, w1 = h // block_size, w // block_size
    c1 = (block_size**2) * c
    input = input.reshape((-1, h1, block_size, w1, block_size, c))
    return input.permute(0, 1, 3, 2, 4, 5).reshape((-1, h1, w1, c1))


def depth_to_space(input, block_size):
    """Rearrange blocks of depth data into spatial."""
    h1, w1, c1 = input.size()[1:]
    h, w = h1 * block_size, w1 * block_size
    c = c1 // (block_size**2)
    input = input.reshape((-1, h1, w1, block_size, block_size, c))
    return input.permute(0, 1, 3, 2, 4, 5).reshape((-1, h, w, c))
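

# Shape note (added for illustration; not part of the original file): with a
# flattened (N, 64 * 64, C) token sequence and block_size=16,
#   space_to_depth(x, 16) -> (N, 4, 4, 16 * 16 * C)   # 4 x 4 grid of windows
#   depth_to_space(y, 16) -> (N, 64, 64, C)           # inverse of the above
# These two helpers are how ImageEncoderViT below switches between windowed
# token sequences and the full spatial feature map.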


class MLP(nn.Module):
    """Two-layer MLP."""

    def __init__(self, dim, mlp_ratio=4):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
        self.fc2 = nn.Linear(int(dim * mlp_ratio), dim)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))


class Attention(nn.Module):
    """Multihead attention."""

    def __init__(self, dim, num_heads, qkv_bias=True):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        # Identity placeholder; ImageEncoderViT swaps in a RelPosEmbed module.
        self.rel_pos_embed = nn.Identity()

    def forward(self, x):
        # (B, L, 3 * dim) -> (3, B, num_heads, L, head_dim)
        qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
        qkv = self.qkv(x).reshape(qkv_shape).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(dim=0)
        attn = q @ k.transpose(-2, -1).mul(self.scale)
        attn = self.rel_pos_embed(attn)
        o = nn.functional.softmax(attn, dim=-1) @ v
        return self.proj(o.transpose(1, 2).flatten(2))


class Block(nn.Module):
    """Transformer block."""

    def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True):
        super(Block, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio)

    def forward(self, x):
        x = self.attn(self.norm1(x)).add_(x)
        return self.mlp(self.norm2(x)).add_(x)


class Bottleneck(nn.Module):
    """The bottleneck block."""

    def __init__(self, dim, expansion=2, width=None):
        super(Bottleneck, self).__init__()
        width = width or dim // expansion
        self.conv1 = nn.Conv2d(dim, width, 1, bias=False)
        self.norm1 = nn.SyncBatchNorm(width)
        self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)
        self.norm2 = nn.SyncBatchNorm(width)
        self.conv3 = nn.Conv2d(width, dim, 1, bias=False)
        self.norm3 = nn.SyncBatchNorm(dim)
        self.activation = nn.GELU()

    def forward(self, x):
        shortcut = x
        x = self.activation(self.norm1(self.conv1(x)))
        x = self.activation(self.norm2(self.conv2(x)))
        return self.norm3(self.conv3(x)).add_(shortcut)


class PatchEmbed(nn.Module):
    """Patch embedding layer."""

    def __init__(self, dim=768, patch_size=16, bias=True):
        super(PatchEmbed, self).__init__()
        self.proj = nn.Conv2d(3, dim, patch_size, patch_size, bias=bias)

    def forward(self, x):
        # (N, 3, H, W) -> (N, (H // patch_size) * (W // patch_size), dim)
        return self.proj(x).flatten(2).transpose(1, 2)


class PosEmbed(nn.Module):
    """Position embedding layer."""

    def __init__(self, dim, num_patches):
        super(PosEmbed, self).__init__()
        self.dim = dim
        self.num_patches = num_patches
        self.weight = nn.Parameter(torch.zeros(num_patches, dim))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        return x.add_(self.weight)


class RelPosEmbed(nn.Module):
    """Relative position embedding layer."""

    def __init__(self, num_heads, size):
        super(RelPosEmbed, self).__init__()
        self.register_buffer("index", self.get_index(size))
        self.weight = nn.Parameter(torch.zeros(num_heads, (2 * size - 1) ** 2))

    @staticmethod
    def get_index(size):
        """Return the relative index."""
        grid = torch.arange(size)
        grid = torch.stack(torch.meshgrid(grid, grid, indexing="ij")).reshape((2, -1))
        coords = grid[:, :, None] - grid[:, None, :] + (size - 1)
        coords[0] *= 2 * size - 1
        return coords.sum(0)

    def get_bias(self):
        return self.weight[:, self.index]

    def forward(self, x):
        return x.add_(self.get_bias())
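

# Note (added for clarity): ``index`` has shape (size**2, size**2) and gathers
# from the (2 * size - 1)**2 learnable offsets per head, so ``get_bias()``
# returns a (num_heads, size**2, size**2) tensor that is broadcast-added to
# the attention logits of each window.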


class SimpleFeaturePyramid(nn.Module):
    """Module to create pyramid features."""

    def __init__(self, embed_dim, out_dim, patch_size=16, min_lvl=4, max_lvl=4):
        super(SimpleFeaturePyramid, self).__init__()
        self.min_lvl, self.max_lvl = min_lvl, max_lvl
        self.input_conv = nn.ModuleList()
        self.lateral_conv = nn.ModuleList()
        self.output_conv = nn.ModuleList()
        patch_lvl = dict((2**i, i) for i in range(6))[patch_size]
        for lvl in [min(i + 2, self.max_lvl) for i in range(4)]:
            if lvl == patch_lvl or lvl < self.min_lvl:
                self.input_conv += [nn.Identity()]
            elif lvl < patch_lvl:
                # Upsample to a finer level with strided transposed convolutions.
                stride, layers = 2 ** (patch_lvl - lvl), []
                while stride > 1:
                    layers += [nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)]
                    layers += [nn.SyncBatchNorm(embed_dim), nn.GELU()] if stride > 2 else []
                    stride /= 2
                self.input_conv.append(nn.Sequential(*layers))
            elif lvl > patch_lvl:
                # Downsample to a coarser level with max pooling.
                stride = 2 ** (lvl - patch_lvl)
                self.input_conv += [nn.MaxPool2d(stride, stride)]
        for _ in range(min_lvl, max_lvl + 1):
            self.lateral_conv.append(
                nn.Sequential(
                    nn.Conv2d(embed_dim, out_dim, kernel_size=1, bias=False),
                    nn.SyncBatchNorm(out_dim),
                )
            )
            self.output_conv.append(
                nn.Sequential(
                    nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
                    nn.SyncBatchNorm(out_dim),
                )
            )

    def forward(self, inputs):
        inputs = inputs + [inputs[-1]] * (4 - len(inputs))
        inputs = [conv(x) for conv, x in zip(self.input_conv, inputs)]
        features = inputs[self.min_lvl - 1 : self.max_lvl]
        laterals = [conv(x) for conv, x in zip(self.lateral_conv, features)]
        return [conv(x) for conv, x in zip(self.output_conv, laterals)]
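

# Note (added for clarity): with the defaults used by ImageEncoderViT below
# (patch_size=16, min_lvl=max_lvl=4), every input_conv entry is nn.Identity
# and only the stride-16 map survives the slice in forward(), so the pyramid
# reduces to a single 1x1 + 3x3 projection of the backbone feature map.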


class ImageEncoderViT(nn.Module):
    """ViT image encoder."""

    def __init__(
        self,
        depth,
        embed_dim,
        num_heads,
        mlp_ratio=4,
        patch_size=16,
        window_size=16,
        image_size=1024,
        out_dim=256,
    ):
        super(ImageEncoderViT, self).__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.window_size = window_size or image_size // patch_size
        self.patch_embed = PatchEmbed(embed_dim, patch_size)
        self.pos_embed = PosEmbed(embed_dim, (image_size // patch_size) ** 2)
        self.blocks = nn.ModuleList(Block(embed_dim, num_heads, mlp_ratio) for _ in range(depth))
        for blk in self.blocks:
            blk.attn.rel_pos_embed = RelPosEmbed(num_heads, self.window_size)
        self.norm = nn.LayerNorm(embed_dim)
        self.cross_conv = nn.ModuleList(Bottleneck(embed_dim) for _ in range(4))
        self.neck = SimpleFeaturePyramid(embed_dim, out_dim, patch_size)
        # Indices of the blocks after which windows are merged and a cross
        # convolution (Bottleneck) propagates information across windows.
        self.cross_indices = list(range(depth // 4 - 1, depth, depth // 4))

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.pos_embed(x)
        # Partition the token map into window_size x window_size windows.
        x = space_to_depth(x, self.window_size)
        wmsa_shape = (-1,) + x.shape[1:]
        msa_shape = (-1, self.window_size**2, self.embed_dim)
        x = x.reshape(msa_shape)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.cross_indices or i == len(self.blocks) - 1:
                x = self.norm(x) if i == len(self.blocks) - 1 else x
                # Merge windows back into a full (N, C, H, W) feature map.
                x = depth_to_space(x.reshape(wmsa_shape), self.window_size)
                x = x.permute(0, 3, 1, 2)
            if i in self.cross_indices:
                x = self.cross_conv[self.cross_indices.index(i)](x)
            if i in self.cross_indices and i < len(self.blocks) - 1:
                # Re-partition into windows for the remaining blocks.
                x = x.permute(0, 2, 3, 1)
                x = space_to_depth(x, self.window_size).reshape(msa_shape)
        return self.neck([x])
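

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (added for illustration; not part of the original
# module). The hyper-parameters below are toy assumptions chosen so the test
# runs quickly on CPU, not the configuration this encoder ships with. In eval
# mode with no distributed process group, recent PyTorch lets SyncBatchNorm
# fall back to plain batch norm; older releases may require a GPU or
# converting the norm layers first.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = ImageEncoderViT(
        depth=4,        # assumed toy depth; must be >= 4 for cross_indices
        embed_dim=64,   # assumed toy width, divisible by num_heads
        num_heads=2,
        patch_size=16,
        window_size=16,
        image_size=256,
        out_dim=32,
    )
    encoder.eval()
    with torch.no_grad():
        features = encoder(torch.randn(1, 3, 256, 256))
    # Expected with these settings: one stride-16 map of shape (1, 32, 16, 16).
    print([tuple(f.shape) for f in features])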