import timm
import torch
import numpy as np
import torch.nn as nn
from einops import rearrange


def disabled_train(self, mode=True):
    """
    Overwrite model.train with this function to make sure train/eval mode does not change anymore
    """
    return self

def simple_conv_and_linear_weights_init(m):
    # Xavier/Glorot-style uniform initialization for conv layers; biases are zeroed.
    if type(m) in [
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
    ]:
        weight_shape = list(m.weight.data.size())
        fan_in = np.prod(weight_shape[1:4])
        fan_out = np.prod(weight_shape[2:4]) * weight_shape[0]
        w_bound = np.sqrt(6.0 / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif type(m) == nn.Linear:
        simple_linear_weights_init(m)


def simple_linear_weights_init(m):
    # Xavier/Glorot-style uniform initialization for linear layers; biases are zeroed.
    if type(m) == nn.Linear:
        weight_shape = list(m.weight.data.size())
        fan_in = weight_shape[1]
        fan_out = weight_shape[0]
        w_bound = np.sqrt(6.0 / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        if m.bias is not None:
            m.bias.data.fill_(0)
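

# A minimal usage sketch (not part of the original file): both init helpers are
# meant to be applied recursively via nn.Module.apply; the small head below is a
# hypothetical example module, not something defined elsewhere in this file.
def _example_weights_init_usage():
    head = nn.Sequential(
        nn.Conv2d(1024, 64, 1),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 32, 1),
    )
    # Re-initializes every conv layer with Xavier-uniform weights and zero biases.
    head.apply(simple_conv_and_linear_weights_init)
    return head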


class Backbone2DWrapper(nn.Module):
    def __init__(self, model, tag, freeze=True):
        super().__init__()
        self.model = model
        self.tag = tag
        self.freeze = freeze
        # Channel dimension of the backbone's final feature map, keyed by model family.
        if 'convnext' in tag:
            self.out_channels = 1024
        elif 'swin' in tag:
            self.out_channels = 1024
        elif 'vit' in tag:
            self.out_channels = 768
        elif 'resnet' in tag:
            self.out_channels = 2048
        else:
            raise NotImplementedError
        if freeze:
            # Freeze all parameters and pin the module to eval mode. Overriding
            # self.train with disabled_train makes a later parent .train(mode) call
            # a no-op for this module, so it cannot be switched back to train mode.
            for param in self.parameters():
                param.requires_grad = False
            self.eval()
            self.train = disabled_train
| def forward_normal(self, x, flat_output=False): | |
| feat = self.model.forward_features(x) | |
| if 'swin' in self.tag: | |
| feat = rearrange(feat, 'b h w c -> b c h w') | |
| if 'vit_base_32_timm_laion2b' in self.tag or 'vit_base_32_timm_openai' in self.tag: | |
| # TODO: [CLS] is prepended to the patches. | |
| feat = rearrange(feat[:, 1:], 'b (h w) c -> b c h w', h=7) | |
| if flat_output: | |
| feat = rearrange(feat, 'b c h w -> b (h w) c') | |
| return feat | |
| def forward_frozen(self, x, flat_output=False): | |
| return self.forward_normal(x, flat_output) | |
| def forward(self, x, flat_output=False): | |
| if self.freeze: | |
| return self.forward_frozen(x, flat_output) | |
| else: | |
| return self.forward_normal(x, flat_output) | |
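

# A minimal sketch (not part of the original file) of wrapping a timm ViT and
# checking output shapes; 'vit_base_patch32_clip_224.openai' is assumed to be an
# available timm model id, and the tag matches the ViT branch handled above.
def _example_backbone2d_wrapper():
    vit = timm.create_model('vit_base_patch32_clip_224.openai', pretrained=False)
    wrapper = Backbone2DWrapper(vit, 'vit_base_32_timm_openai', freeze=True)
    images = torch.zeros(2, 3, 224, 224)
    grid = wrapper(images, flat_output=False)    # (2, 768, 7, 7), [CLS] removed
    tokens = wrapper(images, flat_output=True)   # (2, 49, 768)
    return grid, tokens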


def convnext_base_laion2b(pretrained=False, freeze=True, **kwargs):
    m = timm.create_model(
        'convnext_base.clip_laion2b',
        pretrained=pretrained
    )
    if kwargs.get('reset_clip_s2b2'):
        # Optionally re-initialize the weights and biases of the last block of
        # stage 3 (stages.3.blocks.2) with random normal values.
        s = m.state_dict()
        for i in s.keys():
            if 'stages.3.blocks.2' in i and ('weight' in i or 'bias' in i):
                s[i].normal_()
        m.load_state_dict(s, strict=True)
    return Backbone2DWrapper(m, 'convnext_base_laion2b', freeze=freeze)
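

# A minimal usage sketch (not part of the original file): building the frozen
# ConvNeXt backbone and extracting a 7x7 feature grid from a 224x224 batch.
# pretrained=False here only avoids downloading CLIP weights in this illustration.
def _example_convnext_features():
    backbone = convnext_base_laion2b(pretrained=False, freeze=True)
    images = torch.randn(2, 3, 224, 224)
    grid = backbone(images, flat_output=False)   # (2, 1024, 7, 7) feature map
    tokens = backbone(images, flat_output=True)  # (2, 49, 1024) flattened grid
    return grid, tokens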


class GridFeatureExtractor2D(nn.Module):
    def __init__(self, backbone_name='convnext_base', backbone_pretrain_dataset='laion2b',
                 use_pretrain=True, freeze=True, pooling='avg'):
        super().__init__()
        init_func_name = '_'.join([backbone_name, backbone_pretrain_dataset])
        init_func = globals().get(init_func_name)
        if init_func and callable(init_func):
            self.backbone = init_func(pretrained=use_pretrain, freeze=freeze)
        else:
            raise NotImplementedError(f"Backbone2D does not support {init_func_name}")
        self.pooling = pooling
        if self.pooling:
            if self.pooling == 'avg':
                self.pooling_layers = nn.Sequential(
                    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                    nn.Flatten()
                )
                self.out_channels = self.backbone.out_channels
            elif self.pooling == 'conv':
                self.pooling_layers = nn.Sequential(
                    nn.Conv2d(self.backbone.out_channels, 64, 1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(64, 32, 1),
                    nn.Flatten()
                )
                self.pooling_layers.apply(simple_conv_and_linear_weights_init)
                self.out_channels = 32 * 7 * 7  # hardcoded for 224x224 inputs
            elif self.pooling in ['attn', 'attention']:
                self.visual_attention = nn.Sequential(
                    nn.Conv2d(self.backbone.out_channels, self.backbone.out_channels, 1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(self.backbone.out_channels, self.backbone.out_channels, 1),
                )
                self.visual_attention.apply(simple_conv_and_linear_weights_init)
                def _attention_pooling(x):
                    # Per-channel spatial attention: softmax over the H*W positions,
                    # then a weighted sum of the features at those positions.
                    B, C, H, W = x.size()
                    attn = self.visual_attention(x)
                    attn = attn.view(B, C, -1)
                    x = x.view(B, C, -1)
                    attn = attn.softmax(dim=-1)
                    x = torch.einsum('b c n, b c n -> b c', x, attn)
                    return x

                self.pooling_layers = _attention_pooling
                self.out_channels = self.backbone.out_channels
            else:
                raise NotImplementedError(f"Backbone2D does not support {self.pooling} pooling")
        else:
            self.out_channels = self.backbone.out_channels

    def forward(self, x):
        if self.pooling:
            x = self.backbone(x, flat_output=False)
            x = self.pooling_layers(x).unsqueeze(1)
            return x
        else:
            return self.backbone(x, flat_output=True)
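

# A minimal usage sketch (not part of the original file): with pooling='avg' the
# extractor returns one pooled token per image; with pooling='' or None it returns
# the flattened 7x7 grid tokens instead.
if __name__ == '__main__':
    extractor = GridFeatureExtractor2D(
        backbone_name='convnext_base',
        backbone_pretrain_dataset='laion2b',
        use_pretrain=False,   # skip downloading CLIP weights in this sketch
        freeze=True,
        pooling='avg',
    )
    images = torch.randn(2, 3, 224, 224)
    feats = extractor(images)
    print(feats.shape, extractor.out_channels)   # torch.Size([2, 1, 1024]) 1024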