Spaces:
Runtime error
Runtime error
| """ AvgPool2d w/ Same Padding | |
| Hacked together by / Copyright 2020 Ross Wightman | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import List, Tuple, Optional | |
| from .helpers import to_2tuple | |
| from .padding import pad_same, get_padding_value | |
| def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), | |
| ceil_mode: bool = False, count_include_pad: bool = True): | |
| # FIXME how to deal with count_include_pad vs not for external padding? | |
| x = pad_same(x, kernel_size, stride) | |
| return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) | |
| class AvgPool2dSame(nn.AvgPool2d): | |
| """ Tensorflow like 'SAME' wrapper for 2D average pooling | |
| """ | |
| def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): | |
| kernel_size = to_2tuple(kernel_size) | |
| stride = to_2tuple(stride) | |
| super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) | |
| def forward(self, x): | |
| x = pad_same(x, self.kernel_size, self.stride) | |
| return F.avg_pool2d( | |
| x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) | |
| def max_pool2d_same( | |
| x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), | |
| dilation: List[int] = (1, 1), ceil_mode: bool = False): | |
| x = pad_same(x, kernel_size, stride, value=-float('inf')) | |
| return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) | |
| class MaxPool2dSame(nn.MaxPool2d): | |
| """ Tensorflow like 'SAME' wrapper for 2D max pooling | |
| """ | |
| def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): | |
| kernel_size = to_2tuple(kernel_size) | |
| stride = to_2tuple(stride) | |
| dilation = to_2tuple(dilation) | |
| super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) | |
| def forward(self, x): | |
| x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) | |
| return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) | |
| def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): | |
| stride = stride or kernel_size | |
| padding = kwargs.pop('padding', '') | |
| padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) | |
| if is_dynamic: | |
| if pool_type == 'avg': | |
| return AvgPool2dSame(kernel_size, stride=stride, **kwargs) | |
| elif pool_type == 'max': | |
| return MaxPool2dSame(kernel_size, stride=stride, **kwargs) | |
| else: | |
| assert False, f'Unsupported pool type {pool_type}' | |
| else: | |
| if pool_type == 'avg': | |
| return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) | |
| elif pool_type == 'max': | |
| return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) | |
| else: | |
| assert False, f'Unsupported pool type {pool_type}' | |