Spaces:
Build error
Build error
| """Module of the HOW method""" | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torchvision | |
class HOWNet(nn.Module):
    """Network for the HOW method

    :param torch.nn.Module features: Feature extractor module, called on each rescaled input
    :param torch.nn.Module attention: Attention layer, applied to the raw feature-extractor output
    :param torch.nn.Module smoothing: Smoothing layer applied to the features, or None to skip it
    :param torch.nn.Module dim_reduction: Dimensionality reduction layer applied after smoothing,
        or None to skip it
    :param dict meta: Metadata that are stored with the network
    :param dict runtime: Runtime options that can be used as default for e.g. inference
    """

    def __init__(self, features, attention, smoothing, dim_reduction, meta, runtime):
        super().__init__()
        self.features = features
        self.attention = attention
        self.smoothing = smoothing
        self.dim_reduction = dim_reduction
        self.meta = meta
        self.runtime = runtime

    def copy_excluding_dim_reduction(self):
        """Return a copy of this network without the dim_reduction layer

        The copy shares the features, attention and smoothing submodules and the runtime dict
        with this network. Only ``meta`` is a new (shallow) dict, in which ``outputdim`` is set
        to ``meta['backbone_dim']``.
        """
        meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
        return self.__class__(self.features, self.attention, self.smoothing, None, meta, self.runtime)

    def copy_with_runtime(self, runtime):
        """Return a copy of this network with a different runtime dict

        All submodules and the meta dict are shared with this network.
        """
        return self.__class__(self.features, self.attention, self.smoothing, self.dim_reduction,
                              self.meta, runtime)

    # Methods of nn.Module

    @staticmethod
    def _set_batchnorm_eval(mod):
        """Switch ``mod`` to eval mode if its class name contains 'BatchNorm'"""
        if 'BatchNorm' in mod.__class__.__name__:
            # freeze running mean and std
            mod.eval()

    def train(self, mode=True):
        """Set the training mode like nn.Module.train, but keep BatchNorm layers frozen

        When ``mode`` is true, every submodule whose class name contains 'BatchNorm' is put
        back into eval mode right after, so its running statistics are not updated.

        :return: the result of nn.Module.train (the module itself)
        """
        res = super().train(mode)
        if mode:
            self.apply(self._set_batchnorm_eval)
        return res

    def parameter_groups(self, optimizer_opts):
        """Return torch parameter groups

        :param optimizer_opts: Not used by this method
        :return: list of parameter-group dicts, one per present layer (features, attention,
            smoothing); the dimensionality reduction layer, when present, is appended last
            with ``lr=0.0`` so the optimizer does not change it
        """
        layers = [self.features, self.attention, self.smoothing]
        parameters = [{'params': layer.parameters()} for layer in layers if layer is not None]
        if self.dim_reduction is not None:
            # Do not update dimensionality reduction layer
            parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
        return parameters

    # Forward

    def features_attentions(self, x, *, scales):
        """Return a tuple (features, attentions) where each is a list containing requested scales

        For every factor in ``scales`` the input is bilinearly resized by that factor and passed
        through the feature extractor. The attention map is computed from these raw features.
        The features are then smoothed and dimensionality-reduced when those layers are set.
        Finally, every attention map is divided by the maximum value over all scales jointly, so
        the largest weight across scales becomes 1.

        :param torch.Tensor x: Input batch; bilinear interpolation expects a 4-D tensor,
            presumably (N, C, H, W) images
        :param scales: Iterable of scale factors
        :return: (feats, masks), two lists with one entry per scale in the order of ``scales``;
            both lists are empty when ``scales`` is empty
        """
        feats = []
        masks = []
        for scale in scales:
            resized = nn.functional.interpolate(x, scale_factor=scale, mode='bilinear',
                                                align_corners=False)
            out = self.features(resized)
            # Attention is taken from the backbone output, before smoothing / reduction
            attn = self.attention(out)
            if self.smoothing is not None:
                out = self.smoothing(out)
            if self.dim_reduction is not None:
                out = self.dim_reduction(out)
            feats.append(out)
            masks.append(attn)

        if not masks:
            # No scales requested: max() below would raise on an empty sequence
            return feats, masks

        # Normalize max weight to 1 (jointly over all scales)
        peak = max(mask.max() for mask in masks)
        masks = [mask / peak for mask in masks]
        return feats, masks

    def __repr__(self):
        """Return the class name followed by one ``key: value`` line per meta entry"""
        meta_str = "\n".join(" %s: %s" % item for item in self.meta.items())
        return "%s(meta={\n%s\n})" % (self.__class__.__name__, meta_str)

    def meta_repr(self):
        """Return meta representation"""
        return str(self)