# Copyright (C) 2021-2022 Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

import os

import torch
from torch import nn
import torchvision

from how import layers
from lit import LocalfeatureIntegrationTransformer
from how.networks.how_net import HOWNet


class FIReNet(HOWNet):

    def __init__(self, features, attention, lit, dim_reduction, meta, runtime):
        super().__init__(features, attention, None, dim_reduction, meta, runtime)
        self.lit = lit
        self.return_global = False

    def copy_excluding_dim_reduction(self):
        """Return a copy of this network without the dim_reduction layer"""
        meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
        return self.__class__(self.features, self.attention, self.lit, None, meta, self.runtime)

    def copy_with_runtime(self, runtime):
        """Return a copy of this network with a different runtime dict"""
        return self.__class__(self.features, self.attention, self.lit, self.dim_reduction, self.meta, runtime)

    def parameter_groups(self):
        """Return torch parameter groups"""
        layers = [self.features, self.attention, self.smoothing, self.lit]
        parameters = [{'params': x.parameters()} for x in layers if x is not None]
        if self.dim_reduction:
            # Do not update dimensionality reduction layer
            parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
        return parameters
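
    # Note (illustrative): these parameter groups are meant to be passed to a torch
    # optimizer, e.g. torch.optim.SGD(net.parameter_groups(), lr=1e-3), so the
    # dim_reduction group keeps its per-group lr of 0.0 while the remaining groups
    # use the optimizer's default learning rate.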

    def get_superfeatures(self, x, *, scales):
        """
        Return a tuple (features, attentionmaps, strengths), where each element is a list over the requested scales;
        features is a tensor BxDxNx1,
        attentionmaps is a tensor BxNxHxW
        """
        feats = []
        attns = []
        strengths = []
        for s in scales:
            xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
            o = self.features(xs)
            o, attn = self.lit(o)
            strength = self.attention(o)
            if self.smoothing:
                o = self.smoothing(o)
            if self.dim_reduction:
                o = self.dim_reduction(o)
            feats.append(o)
            attns.append(attn)
            strengths.append(strength)
        return feats, attns, strengths

    def forward(self, x):
        return self.get_superfeatures(x, scales=self.runtime['training_scales'])


def init_network(architecture, pretrained, skip_layer, dim_reduction, lit, runtime):
    """Initialize FIRe network

    :param str architecture: Network backbone architecture (e.g. resnet18)
    :param str pretrained: Path to the pretrained model checkpoint (None for random initialization)
    :param int skip_layer: How many layers of blocks should be skipped (from the end)
    :param dict dim_reduction: Options for the dimensionality reduction layer
    :param dict lit: Options for the lit layer
    :param dict runtime: Runtime options to be stored in the network
    :return FIReNet: Initialized network
    """
    # Take convolutional layers as features, always ends with ReLU to make last activations non-negative
    net_in = getattr(torchvision.models, architecture)(pretrained=False)  # use trained weights including the LIT module instead
    if architecture.startswith('alexnet') or architecture.startswith('vgg'):
        features = list(net_in.features.children())[:-1]
    elif architecture.startswith('resnet'):
        features = list(net_in.children())[:-2]
    elif architecture.startswith('densenet'):
        features = list(net_in.features.children()) + [nn.ReLU(inplace=True)]
    elif architecture.startswith('squeezenet'):
        features = list(net_in.features.children())
    else:
        raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))

    if skip_layer > 0:
        features = features[:-skip_layer]
    # Assumes a backbone whose final feature map has 2048 channels (e.g. ResNet-50), halved per skipped block
    backbone_dim = 2048 // (2 ** skip_layer)
    att_layer = layers.attention.L2Attention()
    lit_layer = LocalfeatureIntegrationTransformer(**lit, input_dim=backbone_dim)

    reduction_layer = None
    if dim_reduction:
        reduction_layer = layers.dim_reduction.ConvDimReduction(**dim_reduction, input_dim=lit['dim'])

    meta = {
        "architecture": architecture,
        "backbone_dim": lit['dim'],
        "outputdim": reduction_layer.out_channels if dim_reduction else lit['dim'],
        "corercf_size": 32 // (2 ** skip_layer),
    }
    net = FIReNet(nn.Sequential(*features), att_layer, lit_layer, reduction_layer, meta, runtime)

    if pretrained is not None:
        assert os.path.isfile(pretrained), pretrained
        ckpt = torch.load(pretrained, map_location='cpu')
        missing, unexpected = net.load_state_dict(ckpt['state_dict'], strict=False)
        assert all(['dim_reduction' in a for a in missing]), "Loading did not go well"
        assert all(['fc' in a for a in unexpected]), "Loading did not go well"
    return net
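

# Usage sketch (illustrative): a commented-out example of how init_network() and the
# multi-scale extraction might be called. The option values below are assumptions made
# for illustration; the full set of keys accepted by the `lit` and `dim_reduction`
# dicts is defined by LocalfeatureIntegrationTransformer and ConvDimReduction, of which
# only 'dim' is read directly in this file.
#
#   lit_opts = {"dim": 1024, ...}                                  # fill in the LIT module's own options
#   runtime = {"training_scales": [2.0, 1.414, 1.0, 0.707, 0.5]}   # hypothetical multi-scale setup
#   net = init_network("resnet50", pretrained="fire.pth", skip_layer=0,
#                      dim_reduction=None, lit=lit_opts, runtime=runtime)
#   net.eval()
#   with torch.no_grad():
#       feats, attns, strengths = net(torch.rand(1, 3, 224, 224))
#   # feats[i] is a BxDxNx1 super-feature tensor and attns[i] a BxNxHxW attention map
#   # for the i-th scale in runtime['training_scales']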