from typing import Any, Dict

import torch
import torch.nn as nn
from diffusers.utils import is_torch_version
from einops import rearrange

from ..modules.vaemodules.activations import get_activation
from ..modules.vaemodules.common import CausalConv3d
from ..modules.vaemodules.down_blocks import get_down_block
from ..modules.vaemodules.mid_blocks import get_mid_block
from ..modules.vaemodules.up_blocks import get_up_block
def create_custom_forward(module, return_dict=None):
    # Wrap `module` so it can be passed to `torch.utils.checkpoint.checkpoint`,
    # optionally forwarding a `return_dict` keyword argument.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        else:
            return module(*inputs)

    return custom_forward
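# Illustrative sketch (an assumption for documentation, not part of the original code): the wrapper
# above is combined with activation checkpointing, as done in `single_forward` below, which
# recomputes the wrapped module's activations during the backward pass to save memory.
def _example_checkpointed_call(block: nn.Module, hidden: torch.Tensor) -> torch.Tensor:  # pragma: no cover
    # `use_reentrant=False` requires torch >= 1.11; the model code guards this with `is_torch_version`.
    return torch.utils.checkpoint.checkpoint(
        create_custom_forward(block), hidden, use_reentrant=False
    )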
class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 8):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("SpatialDownBlock3D",)`):
            The types of down blocks to use.
        ch (`int`, *optional*, defaults to 128):
            The base channel count, used together with `ch_mult` when `block_out_channels` is `None`.
        ch_mult (`List[int]`, *optional*, defaults to `[1, 2, 4, 4]`):
            The channel multipliers applied to `ch` when `block_out_channels` is `None`.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 512)`):
            The number of output channels for each block.
        use_gc_blocks (`Tuple[bool, ...]`, *optional*, defaults to `None`):
            Whether to use global context blocks for each down block.
        mid_block_type (`str`, *optional*, defaults to `"MidBlock3D"`):
            The type of mid block to use.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        num_attention_heads (`int`, *optional*, defaults to 1):
            The number of attention heads to use.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
        slice_mag_vae (`bool`, *optional*, defaults to `False`):
            Whether to split the input along the temporal axis into mini-batches of `mini_batch_encoder` frames and
            encode them independently.
        slice_compression_vae (`bool`, *optional*, defaults to `False`):
            Whether to encode in temporal mini-batches, encoding a single leading frame first when the frame count is odd.
        cache_compression_vae (`bool`, *optional*, defaults to `False`):
            Whether to encode in temporal mini-batches using the cache-compression slicing scheme.
        cache_mag_vae (`bool`, *optional*, defaults to `False`):
            Whether to encode the first frame on its own and the remaining frames in mini-batches, using the MagViT
            padding modes.
        spatial_group_norm (`bool`, *optional*, defaults to `False`):
            Whether to apply the output group normalization per frame instead of over the full 3D volume.
        mini_batch_encoder (`int`, *optional*, defaults to 9):
            The number of frames per temporal mini-batch.
    """
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 8,
        down_block_types=("SpatialDownBlock3D",),
        ch=128,
        ch_mult=[1, 2, 4, 4],
        block_out_channels=[128, 256, 512, 512],
        use_gc_blocks=None,
        mid_block_type: str = "MidBlock3D",
        mid_block_use_attention: bool = True,
        mid_block_attention_type: str = "3d",
        mid_block_num_attention_heads: int = 1,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        num_attention_heads: int = 1,
        double_z: bool = True,
        slice_mag_vae: bool = False,
        slice_compression_vae: bool = False,
        cache_compression_vae: bool = False,
        cache_mag_vae: bool = False,
        spatial_group_norm: bool = False,
        mini_batch_encoder: int = 9,
        verbose=False,
    ):
        super().__init__()
        if block_out_channels is None:
            block_out_channels = [ch * i for i in ch_mult]
        assert len(down_block_types) == len(block_out_channels), (
            "Number of down block types must match number of block output channels."
        )
        if use_gc_blocks is not None:
            assert len(use_gc_blocks) == len(down_block_types), (
                "Number of GC blocks must match number of down block types."
            )
        else:
            use_gc_blocks = [False] * len(down_block_types)

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
        )
        self.down_blocks = nn.ModuleList([])

        output_channels = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channels = output_channels
            output_channels = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            down_block = get_down_block(
                down_block_type,
                in_channels=input_channels,
                out_channels=output_channels,
                num_layers=layers_per_block,
                act_fn=act_fn,
                norm_num_groups=norm_num_groups,
                norm_eps=1e-6,
                num_attention_heads=num_attention_heads,
                add_gc_block=use_gc_blocks[i],
                add_downsample=not is_final_block,
            )
            self.down_blocks.append(down_block)

        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            num_layers=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=1e-6,
            add_attention=mid_block_use_attention,
            attention_type=mid_block_attention_type,
            num_attention_heads=mid_block_num_attention_heads,
        )

        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1],
            num_groups=norm_num_groups,
            eps=1e-6,
        )
        self.conv_act = get_activation(act_fn)

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

        self.slice_mag_vae = slice_mag_vae
        self.slice_compression_vae = slice_compression_vae
        self.cache_compression_vae = cache_compression_vae
        self.cache_mag_vae = cache_mag_vae
        self.mini_batch_encoder = mini_batch_encoder
        self.spatial_group_norm = spatial_group_norm
        self.verbose = verbose
    def _set_padding_flag(self, flag: int):
        # Recursively set `padding_flag` on every submodule that exposes it. The flag selects the
        # temporal padding mode: 1/2 = one/more frames, 3/4 = MagViT one/more frames,
        # 5/6 = cached-slice one/more frames.
        def _set_flag(name, module):
            if hasattr(module, "padding_flag"):
                if self.verbose:
                    print("Set pad mode for module[%s] type=%s" % (name, str(type(module))))
                module.padding_flag = flag
            for sub_name, sub_mod in module.named_children():
                _set_flag(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_flag(name, module)

    def set_padding_one_frame(self):
        self._set_padding_flag(1)

    def set_padding_more_frame(self):
        self._set_padding_flag(2)

    def set_magvit_padding_one_frame(self):
        self._set_padding_flag(3)

    def set_magvit_padding_more_frame(self):
        self._set_padding_flag(4)

    def set_cache_slice_vae_padding_one_frame(self):
        self._set_padding_flag(5)

    def set_cache_slice_vae_padding_more_frame(self):
        self._set_padding_flag(6)

    def set_3dgroupnorm_for_submodule(self):
        # Recursively enable 3D group normalization on every submodule that supports it.
        def _set_3dgroupnorm(name, module):
            if hasattr(module, "set_3dgroupnorm"):
                if self.verbose:
                    print("Set 3D GroupNorm for module[%s] type=%s" % (name, str(type(module))))
                module.set_3dgroupnorm = True
            for sub_name, sub_mod in module.named_children():
                _set_3dgroupnorm(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_3dgroupnorm(name, module)
    def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        if self.training:
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        # Prepend / append neighbouring frames along the temporal axis as context.
        if previous_features is not None and after_features is None:
            x = torch.concat([previous_features, x], 2)
        elif previous_features is None and after_features is not None:
            x = torch.concat([x, after_features], 2)
        elif previous_features is not None and after_features is not None:
            x = torch.concat([previous_features, x, after_features], 2)

        if self.training:
            x = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.conv_in),
                x,
                **ckpt_kwargs,
            )
        else:
            x = self.conv_in(x)

        for down_block in self.down_blocks:
            if self.training:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block),
                    x,
                    **ckpt_kwargs,
                )
            else:
                x = down_block(x)

        x = self.mid_block(x)

        if self.spatial_group_norm:
            batch_size = x.shape[0]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = self.conv_norm_out(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
        else:
            x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)

        # Trim the output along the temporal axis to remove the extra context frames.
        if previous_features is not None and after_features is None:
            x = x[:, :, 1:]
        elif previous_features is None and after_features is not None:
            x = x[:, :, :2]
        elif previous_features is not None and after_features is not None:
            x = x[:, :, 1:3]
        return x
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.spatial_group_norm:
            self.set_3dgroupnorm_for_submodule()

        if self.cache_mag_vae:
            self.set_magvit_padding_one_frame()
            first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
            self.set_magvit_padding_more_frame()
            new_pixel_values = [first_frames]
            for i in range(1, x.shape[2], self.mini_batch_encoder):
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        elif self.cache_compression_vae or self.slice_compression_vae:
            # Both flags follow the same slicing scheme on the encoder side.
            _, _, f, _, _ = x.size()
            if f % 2 != 0:
                self.set_padding_one_frame()
                first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
                self.set_padding_more_frame()
                new_pixel_values = [first_frames]
                start_index = 1
            else:
                self.set_padding_more_frame()
                new_pixel_values = []
                start_index = 0
            for i in range(start_index, x.shape[2], self.mini_batch_encoder):
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        elif self.slice_mag_vae:
            new_pixel_values = []
            for i in range(0, x.shape[2], self.mini_batch_encoder):
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        else:
            new_pixel_values = self.single_forward(x, None, None)
        return new_pixel_values
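# Illustrative usage sketch (an assumption for documentation, not part of the original model code):
# encode a short (B, C, T, H, W) clip with a minimal single-block configuration so that
# `down_block_types` and `block_out_channels` have matching lengths. The block-type name accepted
# by `get_down_block` is assumed from the documented default above.
def _example_encoder_usage():  # pragma: no cover
    encoder = Encoder(
        down_block_types=("SpatialDownBlock3D",),
        block_out_channels=[128],
    )
    encoder.eval()  # avoid the gradient-checkpointing path used in training mode
    video = torch.randn(1, 3, 9, 64, 64)  # (B, C, T, H, W)
    with torch.no_grad():
        latent = encoder(video)  # channel count is 2 * out_channels when double_z=True
    return latent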
class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 8):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("SpatialUpBlock3D",)`):
            The types of up blocks to use.
        ch (`int`, *optional*, defaults to 128):
            The base channel count, used together with `ch_mult` when `block_out_channels` is `None`.
        ch_mult (`List[int]`, *optional*, defaults to `[1, 2, 4, 4]`):
            The channel multipliers applied to `ch` when `block_out_channels` is `None`.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 512)`):
            The number of output channels for each block.
        use_gc_blocks (`Tuple[bool, ...]`, *optional*, defaults to `None`):
            Whether to use global context blocks for each up block.
        mid_block_type (`str`, *optional*, defaults to `"MidBlock3D"`):
            The type of mid block to use.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        num_attention_heads (`int`, *optional*, defaults to 1):
            The number of attention heads to use.
        slice_mag_vae (`bool`, *optional*, defaults to `False`):
            Whether to split the latent along the temporal axis into mini-batches of `mini_batch_decoder` frames and
            decode them independently.
        slice_compression_vae (`bool`, *optional*, defaults to `False`):
            Whether to decode in temporal mini-batches, passing neighbouring latent slices as context to each call.
        cache_compression_vae (`bool`, *optional*, defaults to `False`):
            Whether to decode in temporal mini-batches using the cached-slice padding modes.
        cache_mag_vae (`bool`, *optional*, defaults to `False`):
            Whether to decode the first frame on its own and the remaining frames in mini-batches, using the MagViT
            padding modes.
        spatial_group_norm (`bool`, *optional*, defaults to `False`):
            Whether to apply the output group normalization per frame instead of over the full 3D volume.
        mini_batch_decoder (`int`, *optional*, defaults to 3):
            The number of latent frames per temporal mini-batch.
    """
    def __init__(
        self,
        in_channels: int = 8,
        out_channels: int = 3,
        up_block_types=("SpatialUpBlock3D",),
        ch=128,
        ch_mult=[1, 2, 4, 4],
        block_out_channels=[128, 256, 512, 512],
        use_gc_blocks=None,
        mid_block_type: str = "MidBlock3D",
        mid_block_use_attention: bool = True,
        mid_block_attention_type: str = "3d",
        mid_block_num_attention_heads: int = 1,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        num_attention_heads: int = 1,
        slice_mag_vae: bool = False,
        slice_compression_vae: bool = False,
        cache_compression_vae: bool = False,
        cache_mag_vae: bool = False,
        spatial_group_norm: bool = False,
        mini_batch_decoder: int = 3,
        verbose=False,
    ):
        super().__init__()
        if block_out_channels is None:
            block_out_channels = [ch * i for i in ch_mult]
        assert len(up_block_types) == len(block_out_channels), (
            "Number of up block types must match number of block output channels."
        )
        if use_gc_blocks is not None:
            assert len(use_gc_blocks) == len(up_block_types), (
                "Number of GC blocks must match number of up block types."
            )
        else:
            use_gc_blocks = [False] * len(up_block_types)

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
        )
        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            num_layers=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=1e-6,
            add_attention=mid_block_use_attention,
            attention_type=mid_block_attention_type,
            num_attention_heads=mid_block_num_attention_heads,
        )

        self.up_blocks = nn.ModuleList([])

        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channels = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            input_channels = output_channels
            output_channels = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            up_block = get_up_block(
                up_block_type,
                in_channels=input_channels,
                out_channels=output_channels,
                num_layers=layers_per_block + 1,
                act_fn=act_fn,
                norm_num_groups=norm_num_groups,
                norm_eps=1e-6,
                num_attention_heads=num_attention_heads,
                add_gc_block=use_gc_blocks[i],
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(up_block)

        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0],
            num_groups=norm_num_groups,
            eps=1e-6,
        )
        self.conv_act = get_activation(act_fn)
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.slice_mag_vae = slice_mag_vae
        self.slice_compression_vae = slice_compression_vae
        self.cache_compression_vae = cache_compression_vae
        self.cache_mag_vae = cache_mag_vae
        self.mini_batch_decoder = mini_batch_decoder
        self.spatial_group_norm = spatial_group_norm
        self.verbose = verbose
    def _set_padding_flag(self, flag: int):
        # Recursively set `padding_flag` on every submodule that exposes it. The flag selects the
        # temporal padding mode: 1/2 = one/more frames, 3/4 = MagViT one/more frames,
        # 5/6 = cached-slice one/more frames.
        def _set_flag(name, module):
            if hasattr(module, "padding_flag"):
                if self.verbose:
                    print("Set pad mode for module[%s] type=%s" % (name, str(type(module))))
                module.padding_flag = flag
            for sub_name, sub_mod in module.named_children():
                _set_flag(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_flag(name, module)

    def set_padding_one_frame(self):
        self._set_padding_flag(1)

    def set_padding_more_frame(self):
        self._set_padding_flag(2)

    def set_magvit_padding_one_frame(self):
        self._set_padding_flag(3)

    def set_magvit_padding_more_frame(self):
        self._set_padding_flag(4)

    def set_cache_slice_vae_padding_one_frame(self):
        self._set_padding_flag(5)

    def set_cache_slice_vae_padding_more_frame(self):
        self._set_padding_flag(6)

    def set_3dgroupnorm_for_submodule(self):
        # Recursively enable 3D group normalization on every submodule that supports it.
        def _set_3dgroupnorm(name, module):
            if hasattr(module, "set_3dgroupnorm"):
                if self.verbose:
                    print("Set 3D GroupNorm for module[%s] type=%s" % (name, str(type(module))))
                module.set_3dgroupnorm = True
            for sub_name, sub_mod in module.named_children():
                _set_3dgroupnorm(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_3dgroupnorm(name, module)

    def clear_cache(self):
        # Recursively reset cached temporal features on every submodule that keeps them.
        def _clear_cache(name, module):
            if hasattr(module, "prev_features"):
                if self.verbose:
                    print("Clear cached features for module[%s] type=%s" % (name, str(type(module))))
                module.prev_features = None
            for sub_name, sub_mod in module.named_children():
                _clear_cache(sub_name, sub_mod)

        for name, module in self.named_children():
            _clear_cache(name, module)
    def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        if self.training:
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        # Concatenate neighbouring latent slices along the temporal axis as context, run
        # `conv_in` and `mid_block`, then trim back to the current slice.
        if previous_features is not None and after_features is None:
            b, c, t, h, w = x.size()
            x = torch.concat([previous_features, x], 2)
            x = self.conv_in(x)
            x = self.mid_block(x)
            x = x[:, :, -t:]
        elif previous_features is None and after_features is not None:
            b, c, t, h, w = x.size()
            x = torch.concat([x, after_features], 2)
            x = self.conv_in(x)
            x = self.mid_block(x)
            x = x[:, :, :t]
        elif previous_features is not None and after_features is not None:
            _, _, t_1, _, _ = previous_features.size()
            _, _, t_2, _, _ = x.size()
            x = torch.concat([previous_features, x, after_features], 2)
            x = self.conv_in(x)
            x = self.mid_block(x)
            x = x[:, :, t_1:(t_1 + t_2)]
        else:
            if self.training:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.conv_in),
                    x,
                    **ckpt_kwargs,
                )
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    x,
                    **ckpt_kwargs,
                )
            else:
                x = self.conv_in(x)
                x = self.mid_block(x)

        for up_block in self.up_blocks:
            if self.training:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block),
                    x,
                    **ckpt_kwargs,
                )
            else:
                x = up_block(x)

        if self.spatial_group_norm:
            batch_size = x.shape[0]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = self.conv_norm_out(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
        else:
            x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.spatial_group_norm:
            self.set_3dgroupnorm_for_submodule()

        if self.cache_mag_vae:
            self.set_magvit_padding_one_frame()
            first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
            self.set_magvit_padding_more_frame()
            new_pixel_values = [first_frames]
            for i in range(1, x.shape[2], self.mini_batch_decoder):
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None)
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        elif self.cache_compression_vae:
            _, _, f, _, _ = x.size()
            if f == 1:
                self.set_padding_one_frame()
                first_frames = self.single_forward(x[:, :, :1, :, :], None, None)
                new_pixel_values = [first_frames]
                start_index = 1
            else:
                self.set_cache_slice_vae_padding_one_frame()
                first_frames = self.single_forward(x[:, :, :self.mini_batch_decoder, :, :], None, None)
                new_pixel_values = [first_frames]
                start_index = self.mini_batch_decoder
            for i in range(start_index, x.shape[2], self.mini_batch_decoder):
                self.set_cache_slice_vae_padding_more_frame()
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None)
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        elif self.slice_compression_vae:
            _, _, f, _, _ = x.size()
            if f % 2 != 0:
                self.set_padding_one_frame()
                first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
                self.set_padding_more_frame()
                new_pixel_values = [first_frames]
                start_index = 1
            else:
                self.set_padding_more_frame()
                new_pixel_values = []
                start_index = 0
            previous_features = None
            for i in range(start_index, x.shape[2], self.mini_batch_decoder):
                after_features = (
                    x[:, :, i + self.mini_batch_decoder: i + 2 * self.mini_batch_decoder, :, :]
                    if i + self.mini_batch_decoder < x.shape[2]
                    else None
                )
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], previous_features, after_features)
                previous_features = x[:, :, i: i + self.mini_batch_decoder, :, :]
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        elif self.slice_mag_vae:
            new_pixel_values = []
            for i in range(0, x.shape[2], self.mini_batch_decoder):
                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None)
                new_pixel_values.append(next_frames)
            new_pixel_values = torch.cat(new_pixel_values, dim=2)
        else:
            new_pixel_values = self.single_forward(x, None, None)
        return new_pixel_values
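# Illustrative usage sketch (an assumption for documentation, not part of the original model code):
# decode a (B, C, T, H, W) latent with a minimal single-block configuration so that `up_block_types`
# and `block_out_channels` have matching lengths. The latent shape and the block-type name accepted
# by `get_up_block` are assumed from the documented defaults above.
def _example_decoder_usage():  # pragma: no cover
    decoder = Decoder(
        up_block_types=("SpatialUpBlock3D",),
        block_out_channels=[128],
    )
    decoder.eval()  # avoid the gradient-checkpointing path used in training mode
    latent = torch.randn(1, 8, 4, 8, 8)  # (B, C, T, H, W), C matches Decoder in_channels
    with torch.no_grad():
        video = decoder(latent)  # (B, out_channels, T', H', W')
    return video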