Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange, pack, unpack | |
| from .normalize import Normalize | |
| from .ops import nonlinearity, video_to_image | |
| from .conv import CausalConv3d | |
| from .block import Block | |
class ResnetBlock2D(Block):
    """Residual block built from two 3x3 ``torch.nn.Conv2d`` layers.

    The residual branch is ``norm1 -> nonlinearity -> conv1 -> norm2 ->
    nonlinearity -> dropout -> conv2``. Its output is added to the skip path.
    The skip path is the identity when the channel count is unchanged.
    Otherwise it is a projection convolution (3x3 ``conv_shortcut`` or 1x1
    ``nin_shortcut``).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor. When ``None``
            it defaults to ``in_channels``.
        conv_shortcut: If the channel count changes, project the skip path
            with a 3x3 convolution (``True``) instead of a 1x1 convolution
            (``False``). Ignored when ``in_channels == out_channels``.
        dropout: Dropout probability for the ``torch.nn.Dropout`` layer
            between the two convolutions.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Use the resolved channel count for every layer below. Passing the raw
        # argument would hand ``None`` to the layers when out_channels is omitted.
        out_channels = self.out_channels
        self.use_conv_shortcut = conv_shortcut

        # NOTE(review): Normalize comes from .normalize. Its exact norm type
        # (e.g. GroupNorm) is not visible here. Confirm there.
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        # The skip path only needs a projection when the channel count changes.
        # Attribute names are kept as-is so existing state_dict keys still load.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x):
        """Apply the residual block.

        Args:
            x: Input tensor with ``self.in_channels`` channels. It must be a
                shape accepted by ``torch.nn.Conv2d``.
                NOTE(review): ``video_to_image`` is imported in this module but
                not applied here. Confirm callers flatten 5D video tensors first.

        Returns:
            Tensor with ``self.out_channels`` channels. The spatial size is
            unchanged (3x3 kernels, stride 1, padding 1).
        """
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Match the skip path's channel count to the residual branch.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h
class ResnetBlock3D(Block):
    """Residual block built from two ``CausalConv3d`` layers (kernel size 3).

    The residual branch is ``norm1 -> nonlinearity -> conv1 -> norm2 ->
    nonlinearity -> dropout -> conv2``. Its output is added to the skip path.
    The skip path is the identity when the channel count is unchanged.
    Otherwise it is a projection ``CausalConv3d`` (kernel 3 as
    ``conv_shortcut`` or kernel 1 as ``nin_shortcut``).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor. When ``None``
            it defaults to ``in_channels``.
        conv_shortcut: If the channel count changes, project the skip path
            with a kernel-3 convolution (``True``) instead of a kernel-1
            convolution (``False``). Ignored when ``in_channels == out_channels``.
        dropout: Dropout probability for the ``torch.nn.Dropout`` layer
            between the two convolutions.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Use the resolved channel count for every layer below. Passing the raw
        # argument would hand ``None`` to the layers when out_channels is omitted.
        out_channels = self.out_channels
        self.use_conv_shortcut = conv_shortcut

        # NOTE(review): Normalize and CausalConv3d are defined in sibling
        # modules. Their padding/normalization semantics are not visible here.
        self.norm1 = Normalize(in_channels)
        self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)

        # The skip path only needs a projection when the channel count changes.
        # Attribute names are kept as-is so existing state_dict keys still load.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1)
            else:
                self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0)

    def forward(self, x):
        """Apply the residual block.

        Args:
            x: Input tensor with ``self.in_channels`` channels. It must be a
                shape accepted by ``CausalConv3d``. Presumably
                (batch, channels, time, height, width); confirm in .conv.

        Returns:
            Tensor with ``self.out_channels`` channels. Whether the remaining
            dimensions are preserved depends on ``CausalConv3d``'s padding
            (TODO confirm).
        """
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Match the skip path's channel count to the residual branch.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h