import torch
import torch.nn as nn

from .attention import SpatialAttention, TemporalAttention
from .common import ResidualBlock3D
from .gc_block import GlobalContextBlock
from .upsamplers import (SpatialTemporalUpsampler3D, SpatialUpsampler3D,
                         TemporalUpsampler3D)


def get_up_block(
    up_block_type: str,
    in_channels: int,
    out_channels: int,
    num_layers: int,
    act_fn: str,
    norm_num_groups: int = 32,
    norm_eps: float = 1e-6,
    dropout: float = 0.0,
    num_attention_heads: int = 1,
    output_scale_factor: float = 1.0,
    add_gc_block: bool = False,
    add_upsample: bool = True,
) -> nn.Module:
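    """Build a decoder up block by name.

    ``up_block_type`` selects one of the block classes defined in this module;
    the attention variants derive their per-head size as
    ``out_channels // num_attention_heads``.
    """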
| if up_block_type == "SpatialUpBlock3D": | |
| return SpatialUpBlock3D( | |
| in_channels=in_channels, | |
| out_channels=out_channels, | |
| num_layers=num_layers, | |
| act_fn=act_fn, | |
| norm_num_groups=norm_num_groups, | |
| norm_eps=norm_eps, | |
| dropout=dropout, | |
| output_scale_factor=output_scale_factor, | |
| add_gc_block=add_gc_block, | |
| add_upsample=add_upsample, | |
| ) | |
| elif up_block_type == "SpatialAttnUpBlock3D": | |
| return SpatialAttnUpBlock3D( | |
| in_channels=in_channels, | |
| out_channels=out_channels, | |
| num_layers=num_layers, | |
| act_fn=act_fn, | |
| norm_num_groups=norm_num_groups, | |
| norm_eps=norm_eps, | |
| dropout=dropout, | |
| attention_head_dim=out_channels // num_attention_heads, | |
| output_scale_factor=output_scale_factor, | |
| add_gc_block=add_gc_block, | |
| add_upsample=add_upsample, | |
| ) | |
| elif up_block_type == "TemporalUpBlock3D": | |
| return TemporalUpBlock3D( | |
| in_channels=in_channels, | |
| out_channels=out_channels, | |
| num_layers=num_layers, | |
| act_fn=act_fn, | |
| norm_num_groups=norm_num_groups, | |
| norm_eps=norm_eps, | |
| dropout=dropout, | |
| output_scale_factor=output_scale_factor, | |
| add_gc_block=add_gc_block, | |
| add_upsample=add_upsample, | |
| ) | |
| elif up_block_type == "TemporalAttnUpBlock3D": | |
| return TemporalAttnUpBlock3D( | |
| in_channels=in_channels, | |
| out_channels=out_channels, | |
| num_layers=num_layers, | |
| act_fn=act_fn, | |
| norm_num_groups=norm_num_groups, | |
| norm_eps=norm_eps, | |
| dropout=dropout, | |
| attention_head_dim=out_channels // num_attention_heads, | |
| output_scale_factor=output_scale_factor, | |
| add_gc_block=add_gc_block, | |
| add_upsample=add_upsample, | |
| ) | |
| elif up_block_type == "SpatialTemporalUpBlock3D": | |
| return SpatialTemporalUpBlock3D( | |
| in_channels=in_channels, | |
| out_channels=out_channels, | |
| num_layers=num_layers, | |
| act_fn=act_fn, | |
| norm_num_groups=norm_num_groups, | |
| norm_eps=norm_eps, | |
| dropout=dropout, | |
| output_scale_factor=output_scale_factor, | |
| add_gc_block=add_gc_block, | |
| add_upsample=add_upsample, | |
| ) | |
| else: | |
| raise ValueError(f"Unknown up block type: {up_block_type}") | |


class SpatialUpBlock3D(nn.Module):
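    """Stack of ResidualBlock3D layers with optional global-context fusion,
    followed by spatial (height/width) upsampling."""
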
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()
        # The upsampler and GC block run after the residual convs in forward(),
        # so they are built for out_channels, as in the other up blocks.
        if add_upsample:
            self.upsampler = SpatialUpsampler3D(out_channels, out_channels)
        else:
            self.upsampler = None
        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None
        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)
        if self.gc_block is not None:
            x = self.gc_block(x)
        if self.upsampler is not None:
            x = self.upsampler(x)
        return x


class SpatialAttnUpBlock3D(nn.Module):
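    """Up block that interleaves ResidualBlock3D layers with SpatialAttention,
    with optional global-context fusion and spatial upsampling."""
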
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()
        self.convs = nn.ModuleList([])
        self.attentions = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )
            self.attentions.append(
                SpatialAttention(
                    out_channels,
                    nheads=out_channels // attention_head_dim,
                    head_dim=attention_head_dim,
                    bias=True,
                    upcast_softmax=True,
                    norm_num_groups=norm_num_groups,
                    eps=norm_eps,
                    rescale_output_factor=output_scale_factor,
                    residual_connection=True,
                )
            )
        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None
        if add_upsample:
            self.upsampler = SpatialUpsampler3D(out_channels, out_channels)
        else:
            self.upsampler = None

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv, attn in zip(self.convs, self.attentions):
            x = conv(x)
            x = attn(x)
        if self.gc_block is not None:
            x = self.gc_block(x)
        if self.upsampler is not None:
            x = self.upsampler(x)
        return x


class TemporalUpBlock3D(nn.Module):
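    """Stack of ResidualBlock3D layers with optional global-context fusion,
    followed by temporal (frame) upsampling."""
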
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()
        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )
        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None
        if add_upsample:
            self.upsampler = TemporalUpsampler3D(out_channels, out_channels)
        else:
            self.upsampler = None

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)
        if self.gc_block is not None:
            x = self.gc_block(x)
        if self.upsampler is not None:
            x = self.upsampler(x)
        return x


class TemporalAttnUpBlock3D(nn.Module):
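    """Up block that interleaves ResidualBlock3D layers with TemporalAttention,
    with optional global-context fusion and temporal upsampling."""
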
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()
        self.convs = nn.ModuleList([])
        self.attentions = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )
            self.attentions.append(
                TemporalAttention(
                    out_channels,
                    nheads=out_channels // attention_head_dim,
                    head_dim=attention_head_dim,
                    bias=True,
                    upcast_softmax=True,
                    norm_num_groups=norm_num_groups,
                    eps=norm_eps,
                    rescale_output_factor=output_scale_factor,
                    residual_connection=True,
                )
            )
        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None
        if add_upsample:
            self.upsampler = TemporalUpsampler3D(out_channels, out_channels)
        else:
            self.upsampler = None

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv, attn in zip(self.convs, self.attentions):
            x = conv(x)
            x = attn(x)
        if self.gc_block is not None:
            x = self.gc_block(x)
        if self.upsampler is not None:
            x = self.upsampler(x)
        return x


class SpatialTemporalUpBlock3D(nn.Module):
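    """Stack of ResidualBlock3D layers with optional global-context fusion,
    followed by joint spatial-temporal upsampling."""
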
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()
        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )
        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None
        if add_upsample:
            self.upsampler = SpatialTemporalUpsampler3D(out_channels, out_channels)
        else:
            self.upsampler = None

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)
        if self.gc_block is not None:
            x = self.gc_block(x)
        if self.upsampler is not None:
            x = self.upsampler(x)
        return x
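

# Minimal usage sketch (illustrative only; the channel counts, head count, and
# (B, C, T, H, W) input layout below are assumptions, not values defined in
# this module):
#
#   block = get_up_block(
#       "SpatialAttnUpBlock3D",
#       in_channels=512,
#       out_channels=256,
#       num_layers=2,
#       act_fn="silu",
#       num_attention_heads=8,
#   )
#   y = block(torch.randn(1, 512, 4, 32, 32))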