Spaces:
Runtime error
Runtime error
| """ Median Pool | |
| Hacked together by / Copyright 2020 Ross Wightman | |
| """ | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .helpers import to_2tuple, to_4tuple | |
class MedianPool2d(nn.Module):
    """Median pooling layer; with ``stride=1`` it behaves as a median filter.

    Each output element is the median of one ``kernel_size`` window taken
    from the reflect-padded input.

    Args:
        kernel_size: pooling window size, an int or a 2-tuple (h, w).
        stride: step between windows, an int or a 2-tuple (h, w).
        padding: explicit padding, an int or a 4-tuple ordered
            (left, right, top, bottom) as consumed by ``F.pad``.
        same: if True, ignore ``padding`` and derive TF-style "same"
            padding from the input's spatial size on every forward call.
    """

    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super().__init__()
        self.k = to_2tuple(kernel_size)
        self.stride = to_2tuple(stride)
        # Stored in F.pad order for the last two dims: (left, right, top, bottom).
        self.padding = to_4tuple(padding)
        self.same = same

    @staticmethod
    def _same_total(size, kernel, step):
        """Return the total TF-style "same" padding for one spatial axis.

        When the axis length divides evenly by ``step`` the window overlap is
        ``step``; otherwise it is the leftover ``size % step``. The result is
        clamped at zero.
        """
        leftover = size % step
        overlap = step if leftover == 0 else leftover
        return max(kernel - overlap, 0)

    def _padding(self, x):
        """Return the (left, right, top, bottom) padding to apply to ``x``."""
        if not self.same:
            return self.padding
        ih, iw = x.size()[2:]
        total_h = self._same_total(ih, self.k[0], self.stride[0])
        total_w = self._same_total(iw, self.k[1], self.stride[1])
        # Split each total in half; an odd leftover pixel goes right / bottom.
        left = total_w // 2
        top = total_h // 2
        return (left, total_w - left, top, total_h - top)

    def forward(self, x):
        padded = F.pad(x, self._padding(x), mode='reflect')
        # Slide windows over dims 2 and 3; each window lands in two new trailing dims.
        windows = padded.unfold(2, self.k[0], self.stride[0])
        windows = windows.unfold(3, self.k[1], self.stride[1])
        # Keep the first four dims, merge each window's elements into the last
        # dim, then take the median over that dim.
        flat = windows.flatten(start_dim=4)
        return flat.median(dim=-1).values