Spaces:
Running
Running
| # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. | |
| # All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Optional, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from diffusers.models.autoencoders.autoencoder_kl_cogvideox import ( | |
| CogVideoXCausalConv3d, CogVideoXDownBlock3D, CogVideoXMidBlock3D, | |
| CogVideoXSafeConv3d, CogVideoXSpatialNorm3D, CogVideoXUpBlock3D) | |
| from diffusers.utils import logging | |
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
class CogVideoXEncoder3D(nn.Module):
    r"""
    The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 16):
            The number of latent channels. The output convolution produces `2 * out_channels` channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXDownBlock3D"` entries):
            The types of down blocks to use. Only `"CogVideoXDownBlock3D"` is accepted.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block. Must provide at least one entry per down block.
        layers_per_block (`int`, *optional*, defaults to 3):
            The number of layers per down block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function name forwarded to the down blocks and the mid block.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon forwarded to the down blocks and the mid block as `resnet_eps`.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout value forwarded to the down blocks and the mid block.
        pad_mode (`str`, *optional*, defaults to `"first"`):
            The padding mode forwarded to the input/output causal convolutions and the mid block.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            The first `int(log2(temporal_compression_ratio))` down blocks are built with `compress_time=True`.

    Raises:
        ValueError: If `down_block_types` has more entries than `block_out_channels`, or if it contains a block
            type other than `"CogVideoXDownBlock3D"`.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        down_block_types: Tuple[str, ...] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # Every down block reads `block_out_channels[i]`; fail early with a clear message instead of an
        # `IndexError` in the middle of construction.
        if len(down_block_types) > len(block_out_channels):
            raise ValueError(
                f"`down_block_types` has {len(down_block_types)} entries but `block_out_channels` only has "
                f"{len(block_out_channels)}; each down block needs its own output channel count."
            )

        # log2 of the temporal compression ratio: number of leading blocks that also compress time
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
        self.down_blocks = nn.ModuleList([])

        # down blocks
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # The block matching the last channel entry keeps its resolution (no downsample layer).
            is_final_block = i == len(block_out_channels) - 1
            compress_time = i < temporal_compress_level

            if down_block_type == "CogVideoXDownBlock3D":
                down_block = CogVideoXDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    add_downsample=not is_final_block,
                    compress_time=compress_time,
                )
            else:
                raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")

            self.down_blocks.append(down_block)

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=block_out_channels[-1],
            temb_channels=0,
            dropout=dropout,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            pad_mode=pad_mode,
        )

        # NOTE(review): eps is hard-coded to 1e-6 here rather than taken from `norm_eps`; kept as-is.
        self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def _clear_fake_context_parallel_cache(self):
        """Call `_clear_fake_context_parallel_cache()` on every `CogVideoXCausalConv3d` submodule."""
        for name, module in self.named_modules():
            if isinstance(module, CogVideoXCausalConv3d):
                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
                module._clear_fake_context_parallel_cache()

    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""
        The forward method of the `CogVideoXEncoder3D` class.

        Args:
            sample (`torch.Tensor`):
                The input passed to `conv_in`.
            temb (`torch.Tensor`, *optional*):
                Forwarded unchanged as the second argument of every down block and of the mid block.

        Returns:
            `torch.Tensor`: The output of `conv_out`, which has `2 * out_channels` channels.
        """
        hidden_states = self.conv_in(sample)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Down
            for down_block in self.down_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block), hidden_states, temb, None
                )

            # 2. Mid
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), hidden_states, temb, None
            )
        else:
            # 1. Down
            for down_block in self.down_blocks:
                hidden_states = down_block(hidden_states, temb, None)

            # 2. Mid
            hidden_states = self.mid_block(hidden_states, temb, None)

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states
class CogVideoXDecoder3D(nn.Module):
    r"""
    The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            The number of input (latent) channels. Also used as `spatial_norm_dim` for the mid block, the up blocks
            and the output normalization.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXUpBlock3D"` entries):
            The types of up blocks to use. Only `"CogVideoXUpBlock3D"` is accepted.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block. The tuple is consumed in reverse order and must provide
            at least one entry per up block.
        layers_per_block (`int`, *optional*, defaults to 3):
            Each up block is built with `layers_per_block + 1` layers.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function name forwarded to the mid block and the up blocks.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon forwarded to the mid block and the up blocks as `resnet_eps`.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout value forwarded to the up blocks.
        pad_mode (`str`, *optional*, defaults to `"first"`):
            The padding mode forwarded to the causal convolutions, the mid block and the up blocks.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            The first `int(log2(temporal_compression_ratio))` up blocks are built with `compress_time=True`.

    Raises:
        ValueError: If `up_block_types` has more entries than `block_out_channels`, or if it contains a block type
            other than `"CogVideoXUpBlock3D"`.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # Every up block reads `reversed_block_out_channels[i]`; fail early with a clear message instead of an
        # `IndexError` in the middle of construction.
        if len(up_block_types) > len(block_out_channels):
            raise ValueError(
                f"`up_block_types` has {len(up_block_types)} entries but `block_out_channels` only has "
                f"{len(block_out_channels)}; each up block needs its own output channel count."
            )

        # The decoder walks the channel configuration from the widest entry back to the narrowest.
        reversed_block_out_channels = list(reversed(block_out_channels))

        self.conv_in = CogVideoXCausalConv3d(
            in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
        )

        # mid block
        # NOTE(review): unlike the encoder, `dropout` is not forwarded to the mid block here; kept as-is.
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=reversed_block_out_channels[0],
            temb_channels=0,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            spatial_norm_dim=in_channels,
            pad_mode=pad_mode,
        )

        # up blocks
        self.up_blocks = nn.ModuleList([])

        output_channel = reversed_block_out_channels[0]
        # log2 of the temporal compression ratio: number of leading blocks built with `compress_time=True`
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            # The block matching the last channel entry keeps its resolution (no upsample layer).
            is_final_block = i == len(block_out_channels) - 1
            compress_time = i < temporal_compress_level

            if up_block_type == "CogVideoXUpBlock3D":
                up_block = CogVideoXUpBlock3D(
                    in_channels=prev_output_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    spatial_norm_dim=in_channels,
                    add_upsample=not is_final_block,
                    compress_time=compress_time,
                    pad_mode=pad_mode,
                )
            else:
                raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")

            self.up_blocks.append(up_block)

        self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def _clear_fake_context_parallel_cache(self):
        """Call `_clear_fake_context_parallel_cache()` on every `CogVideoXCausalConv3d` submodule."""
        for name, module in self.named_modules():
            if isinstance(module, CogVideoXCausalConv3d):
                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
                module._clear_fake_context_parallel_cache()

    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""
        The forward method of the `CogVideoXDecoder3D` class.

        Args:
            sample (`torch.Tensor`):
                The input passed to `conv_in`. The same tensor is also handed to the mid block, every up block and
                the output normalization as their extra argument.
            temb (`torch.Tensor`, *optional*):
                Forwarded unchanged as the second argument of the mid block and of every up block.

        Returns:
            `torch.Tensor`: The output of `conv_out`, which has `out_channels` channels.
        """
        hidden_states = self.conv_in(sample)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Mid
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), hidden_states, temb, sample
            )

            # 2. Up
            for up_block in self.up_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block), hidden_states, temb, sample
                )
        else:
            # 1. Mid
            hidden_states = self.mid_block(hidden_states, temb, sample)

            # 2. Up
            for up_block in self.up_blocks:
                hidden_states = up_block(hidden_states, temb, sample)

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states, sample)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states