Spaces:
Runtime error
Runtime error
| """ Test Time Pooling (Average-Max Pool) | |
| Hacked together by / Copyright 2020 Ross Wightman | |
| """ | |
| import logging | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from .adaptive_avgmax_pool import adaptive_avgmax_pool2d | |
| _logger = logging.getLogger(__name__) | |
| class TestTimePoolHead(nn.Module): | |
| def __init__(self, base, original_pool=7): | |
| super(TestTimePoolHead, self).__init__() | |
| self.base = base | |
| self.original_pool = original_pool | |
| base_fc = self.base.get_classifier() | |
| if isinstance(base_fc, nn.Conv2d): | |
| self.fc = base_fc | |
| else: | |
| self.fc = nn.Conv2d( | |
| self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) | |
| self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) | |
| self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) | |
| self.base.reset_classifier(0) # delete original fc layer | |
| def forward(self, x): | |
| x = self.base.forward_features(x) | |
| x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) | |
| x = self.fc(x) | |
| x = adaptive_avgmax_pool2d(x, 1) | |
| return x.view(x.size(0), -1) | |
| def apply_test_time_pool(model, config, use_test_size=True): | |
| test_time_pool = False | |
| if not hasattr(model, 'default_cfg') or not model.default_cfg: | |
| return model, False | |
| if use_test_size and 'test_input_size' in model.default_cfg: | |
| df_input_size = model.default_cfg['test_input_size'] | |
| else: | |
| df_input_size = model.default_cfg['input_size'] | |
| if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: | |
| _logger.info('Target input size %s > pretrained default %s, using test time pooling' % | |
| (str(config['input_size'][-2:]), str(df_input_size[-2:]))) | |
| model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) | |
| test_time_pool = True | |
| return model, test_time_pool | |