Spaces:
Runtime error
| import torch | |
| from torch import nn | |
| from torchvision import models | |
| from einops import rearrange | |
| from torchvision.models._utils import IntermediateLayerGetter | |
class Vgg(nn.Module):
    """VGG-BN convolutional backbone that encodes an image batch as a feature sequence.

    Every ``MaxPool2d`` in the torchvision VGG feature extractor is replaced, in
    order, by an ``AvgPool2d`` whose kernel size and stride come from ``ks`` and
    ``ss``. This lets the caller control how much each pooling stage shrinks the
    height and width. A 1x1 convolution then projects the 512 VGG output channels
    to ``hidden``.

    Args:
        name: Backbone variant, ``'vgg11_bn'`` or ``'vgg19_bn'``.
        ss: Strides for the pooling stages, one entry per max-pool layer of the
            backbone, consumed in order.
        ks: Kernel sizes for the pooling stages, one entry per max-pool layer.
        hidden: Number of output channels of the final 1x1 projection.
        pretrained: If True, load torchvision's default pretrained weights;
            otherwise build a randomly initialised backbone.
        dropout: Dropout probability applied before the 1x1 projection.

    Raises:
        ValueError: If ``name`` is not a supported variant, or if ``ss``/``ks``
            have fewer entries than the backbone has max-pool layers.
    """

    def __init__(self, name, ss, ks, hidden, pretrained=True, dropout=0.5):
        super().__init__()
        if name == 'vgg11_bn':
            factory = models.vgg11_bn
        elif name == 'vgg19_bn':
            factory = models.vgg19_bn
        else:
            # Fail loudly instead of hitting an unbound local further down.
            raise ValueError(
                f"unsupported VGG variant {name!r}; expected 'vgg11_bn' or 'vgg19_bn'"
            )

        # Honour ``pretrained``: weights=None yields a randomly initialised network.
        cnn = factory(weights='DEFAULT' if pretrained else None)
        features = cnn.features

        pool_indices = [
            idx for idx, layer in enumerate(features) if isinstance(layer, nn.MaxPool2d)
        ]
        if len(ss) < len(pool_indices) or len(ks) < len(pool_indices):
            raise ValueError(
                f"{name} has {len(pool_indices)} pooling layers but got "
                f"{len(ss)} strides and {len(ks)} kernel sizes"
            )
        # Swap each max-pool for an average-pool with the caller-chosen geometry.
        for stage, idx in enumerate(pool_indices):
            features[idx] = nn.AvgPool2d(
                kernel_size=ks[stage], stride=ss[stage], padding=0
            )

        self.features = features
        self.dropout = nn.Dropout(dropout)
        # VGG feature extractors end with 512 channels.
        self.last_conv_1x1 = nn.Conv2d(512, hidden, 1)

    def forward(self, x):
        """Encode an image batch as a sequence of feature vectors.

        Shape:
            - x: (N, C, H, W)
            - output: (W' * H', N, hidden), where H' and W' are the spatial sizes
              after the backbone. The sequence is ordered column by column (all
              rows of column 0, then column 1, ...), so when the pooling
              configuration reduces H' to 1 the output is (W', N, hidden).
        """
        conv = self.features(x)
        conv = self.dropout(conv)
        conv = self.last_conv_1x1(conv)  # (N, hidden, H', W')
        conv = conv.transpose(-1, -2)    # (N, hidden, W', H')
        conv = conv.flatten(2)           # (N, hidden, W' * H'), column-major
        return conv.permute(-1, 0, 1)    # (W' * H', N, hidden)
def vgg11_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    """Build a ``Vgg`` backbone on top of torchvision's ``vgg11_bn``.

    All arguments are forwarded unchanged to ``Vgg``.
    """
    backbone = Vgg('vgg11_bn', ss, ks, hidden, pretrained=pretrained, dropout=dropout)
    return backbone
def vgg19_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    """Build a ``Vgg`` backbone on top of torchvision's ``vgg19_bn``.

    All arguments are forwarded unchanged to ``Vgg``.
    """
    backbone = Vgg('vgg19_bn', ss, ks, hidden, pretrained=pretrained, dropout=dropout)
    return backbone