Spaces:
Runtime error
Runtime error
| """ PyTorch Mixed Convolution | |
| Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) | |
| Hacked together by / Copyright 2020 Ross Wightman | |
| """ | |
| import torch | |
| from torch import nn as nn | |
| from .conv2d_same import create_conv2d_pad | |
| def _split_channels(num_chan, num_groups): | |
| split = [num_chan // num_groups for _ in range(num_groups)] | |
| split[0] += num_chan - sum(split) | |
| return split | |
class MixedConv2d(nn.ModuleDict):
    """ Mixed Grouped Convolution

    Based on MDConv and GroupedConv in MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py

    Input and output channels are each divided into one group per kernel size
    (remainder channels go to the first group). Group ``i`` of the input is fed
    to its own sub-convolution using ``kernel_size[i]``, and the per-group
    outputs are concatenated along dim 1.

    Args:
        in_channels: total number of input channels.
        out_channels: total number of output channels.
        kernel_size: one kernel size, or a list holding one kernel size per group.
        stride: forwarded unchanged to every sub-convolution.
        padding: forwarded unchanged to every sub-convolution.
        dilation: forwarded unchanged to every sub-convolution.
        depthwise: when True each sub-convolution uses ``groups`` equal to its
            input-channel count; otherwise ``groups=1``.
        **kwargs: extra arguments forwarded to ``create_conv2d_pad``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, depthwise=False, **kwargs):
        super().__init__()

        # Only a list means "one kernel per group"; any other value becomes a single group.
        kernels = kernel_size if isinstance(kernel_size, list) else [kernel_size]
        group_count = len(kernels)
        in_splits = _split_channels(in_channels, group_count)
        out_splits = _split_channels(out_channels, group_count)
        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)

        group_specs = zip(kernels, in_splits, out_splits)
        for position, (kernel, group_in, group_out) in enumerate(group_specs):
            sub_conv = create_conv2d_pad(
                group_in, group_out, kernel,
                stride=stride, padding=padding, dilation=dilation,
                groups=group_in if depthwise else 1,
                **kwargs)
            # Children are registered as '0', '1', ... via add_module to keep the key space clean.
            self.add_module(str(position), sub_conv)

        # Per-group input channel counts, used to slice the input in forward().
        self.splits = in_splits

    def forward(self, x):
        """Apply each sub-convolution to its channel slice of ``x`` and concatenate on dim 1."""
        chunks = torch.split(x, self.splits, 1)
        # One sub-conv is registered per split, so the two sequences have equal length.
        outputs = [conv(chunk) for conv, chunk in zip(self.values(), chunks)]
        return torch.cat(outputs, 1)