EuuIia committed on
Commit 3f2c798 · verified · 1 Parent(s): ce15b1e

Upload causal_video_autoencoder.py

LTX-Video/ltx_video/models/autoencoders/causal_video_autoencoder.py CHANGED
@@ -0,0 +1,1448 @@
1
+ import json
2
+ import os
3
+ from functools import partial
4
+ from types import SimpleNamespace
5
+ from typing import Any, Mapping, Optional, Tuple, Union, List
6
+ from pathlib import Path
7
+
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ import numpy as np
10
+
11
+ import torch
14
+ from einops import rearrange
15
+ from torch import nn
16
+ from diffusers.utils import logging
17
+ import torch.nn.functional as F
18
+ from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
19
+ from safetensors import safe_open
20
+
21
+
22
+ from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
23
+ from ltx_video.models.autoencoders.pixel_norm import PixelNorm
24
+ from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND
25
+ from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
26
+ from ltx_video.models.transformers.attention import Attention
27
+ from ltx_video.utils.diffusers_config_mapping import (
28
+ diffusers_and_ours_config_mapping,
29
+ make_hashable_key,
30
+ VAE_KEYS_RENAME_DICT,
31
+ )
32
+
33
+ PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics."
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ class CausalVideoAutoencoder(AutoencoderKLWrapper):
38
+ @classmethod
39
+ def from_pretrained(
40
+ cls,
41
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
42
+ *args,
43
+ **kwargs,
44
+ ):
45
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
46
+ if (
47
+ pretrained_model_name_or_path.is_dir()
48
+ and (pretrained_model_name_or_path / "autoencoder.pth").exists()
49
+ ):
50
+ config_local_path = pretrained_model_name_or_path / "config.json"
51
+ config = cls.load_config(config_local_path, **kwargs)
52
+
53
+ model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
54
+ state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))
55
+
56
+ statistics_local_path = (
57
+ pretrained_model_name_or_path / "per_channel_statistics.json"
58
+ )
59
+ if statistics_local_path.exists():
60
+ with open(statistics_local_path, "r") as file:
61
+ data = json.load(file)
62
+ transposed_data = list(zip(*data["data"]))
63
+ data_dict = {
64
+ col: torch.tensor(vals)
65
+ for col, vals in zip(data["columns"], transposed_data)
66
+ }
67
+ std_of_means = data_dict["std-of-means"]
68
+ mean_of_means = data_dict.get(
69
+ "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
70
+ )
71
+ state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = (
72
+ std_of_means
73
+ )
74
+ state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = (
75
+ mean_of_means
76
+ )
77
+
78
+ elif pretrained_model_name_or_path.is_dir():
79
+ config_path = pretrained_model_name_or_path / "vae" / "config.json"
80
+ with open(config_path, "r") as f:
81
+ config = make_hashable_key(json.load(f))
82
+
83
+ assert config in diffusers_and_ours_config_mapping, (
84
+ "Provided diffusers checkpoint config for VAE is not suppported. "
85
+ "We only support diffusers configs found in Lightricks/LTX-Video."
86
+ )
87
+
88
+ config = diffusers_and_ours_config_mapping[config]
89
+
90
+ state_dict_path = (
91
+ pretrained_model_name_or_path
92
+ / "vae"
93
+ / "diffusion_pytorch_model.safetensors"
94
+ )
95
+
96
+ state_dict = {}
97
+ with safe_open(state_dict_path, framework="pt", device="cpu") as f:
98
+ for k in f.keys():
99
+ state_dict[k] = f.get_tensor(k)
100
+ for key in list(state_dict.keys()):
101
+ new_key = key
102
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
103
+ new_key = new_key.replace(replace_key, rename_key)
104
+
105
+ state_dict[new_key] = state_dict.pop(key)
106
+
107
+ elif pretrained_model_name_or_path.is_file() and str(
108
+ pretrained_model_name_or_path
109
+ ).endswith(".safetensors"):
110
+ state_dict = {}
111
+ with safe_open(
112
+ pretrained_model_name_or_path, framework="pt", device="cpu"
113
+ ) as f:
114
+ metadata = f.metadata()
115
+ for k in f.keys():
116
+ state_dict[k] = f.get_tensor(k)
117
+ configs = json.loads(metadata["config"])
118
+ config = configs["vae"]
119
+
120
+ video_vae = cls.from_config(config)
121
+ if "torch_dtype" in kwargs:
122
+ video_vae.to(kwargs["torch_dtype"])
123
+ video_vae.load_state_dict(state_dict)
124
+ return video_vae
125
+
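+     # Usage sketch (hedged; the paths below are placeholders, not shipped files):
+     #   vae = CausalVideoAutoencoder.from_pretrained("path/to/ltx_checkpoint_dir")
+     #   vae = CausalVideoAutoencoder.from_pretrained("path/to/ltx-video.safetensors", torch_dtype=torch.bfloat16)
+     # A directory is expected to hold either config.json + autoencoder.pth (plus an optional
+     # per_channel_statistics.json) or a diffusers-style vae/ subfolder; a single .safetensors
+     # file must carry a "config" entry in its metadata.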
126
+ @staticmethod
127
+ def from_config(config):
128
+ assert (
129
+ config["_class_name"] == "CausalVideoAutoencoder"
130
+ ), "config must have _class_name=CausalVideoAutoencoder"
131
+ if isinstance(config["dims"], list):
132
+ config["dims"] = tuple(config["dims"])
133
+
134
+ assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
135
+
136
+ double_z = config.get("double_z", True)
137
+ latent_log_var = config.get(
138
+ "latent_log_var", "per_channel" if double_z else "none"
139
+ )
140
+ use_quant_conv = config.get("use_quant_conv", True)
141
+ normalize_latent_channels = config.get("normalize_latent_channels", False)
142
+
143
+ if use_quant_conv and latent_log_var in ["uniform", "constant"]:
144
+ raise ValueError(
145
+ f"latent_log_var={latent_log_var} requires use_quant_conv=False"
146
+ )
147
+
148
+ encoder = Encoder(
149
+ dims=config["dims"],
150
+ in_channels=config.get("in_channels", 3),
151
+ out_channels=config["latent_channels"],
152
+ blocks=config.get("encoder_blocks", config.get("blocks")),
153
+ patch_size=config.get("patch_size", 1),
154
+ latent_log_var=latent_log_var,
155
+ norm_layer=config.get("norm_layer", "group_norm"),
156
+ base_channels=config.get("encoder_base_channels", 128),
157
+ spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
158
+ )
159
+
160
+ decoder = Decoder(
161
+ dims=config["dims"],
162
+ in_channels=config["latent_channels"],
163
+ out_channels=config.get("out_channels", 3),
164
+ blocks=config.get("decoder_blocks", config.get("blocks")),
165
+ patch_size=config.get("patch_size", 1),
166
+ norm_layer=config.get("norm_layer", "group_norm"),
167
+ causal=config.get("causal_decoder", False),
168
+ timestep_conditioning=config.get("timestep_conditioning", False),
169
+ base_channels=config.get("decoder_base_channels", 128),
170
+ spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
171
+ )
172
+
173
+ dims = config["dims"]
174
+ return CausalVideoAutoencoder(
175
+ encoder=encoder,
176
+ decoder=decoder,
177
+ latent_channels=config["latent_channels"],
178
+ dims=dims,
179
+ use_quant_conv=use_quant_conv,
180
+ normalize_latent_channels=normalize_latent_channels,
181
+ )
182
+
183
+ @property
184
+ def config(self):
185
+ return SimpleNamespace(
186
+ _class_name="CausalVideoAutoencoder",
187
+ dims=self.dims,
188
+ in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
189
+ out_channels=self.decoder.conv_out.out_channels
190
+ // self.decoder.patch_size**2,
191
+ latent_channels=self.decoder.conv_in.in_channels,
192
+ encoder_blocks=self.encoder.blocks_desc,
193
+ decoder_blocks=self.decoder.blocks_desc,
194
+ scaling_factor=1.0,
195
+ norm_layer=self.encoder.norm_layer,
196
+ patch_size=self.encoder.patch_size,
197
+ latent_log_var=self.encoder.latent_log_var,
198
+ use_quant_conv=self.use_quant_conv,
199
+ causal_decoder=self.decoder.causal,
200
+ timestep_conditioning=self.decoder.timestep_conditioning,
201
+ normalize_latent_channels=self.normalize_latent_channels,
202
+ )
203
+
204
+ @property
205
+ def is_video_supported(self):
206
+ """
207
+ Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
208
+ """
209
+ return self.dims != 2
210
+
211
+ @property
212
+ def spatial_downscale_factor(self):
213
+ return (
214
+ 2
215
+ ** len(
216
+ [
217
+ block
218
+ for block in self.encoder.blocks_desc
219
+ if block[0]
220
+ in [
221
+ "compress_space",
222
+ "compress_all",
223
+ "compress_all_res",
224
+ "compress_space_res",
225
+ ]
226
+ ]
227
+ )
228
+ * self.encoder.patch_size
229
+ )
230
+
231
+ @property
232
+ def temporal_downscale_factor(self):
233
+ return 2 ** len(
234
+ [
235
+ block
236
+ for block in self.encoder.blocks_desc
237
+ if block[0]
238
+ in [
239
+ "compress_time",
240
+ "compress_all",
241
+ "compress_all_res",
242
+ "compress_time_res",
243
+ ]
244
+ ]
245
+ )
246
+
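+     # Worked example (derived from the two properties above, using the demo config at the
+     # bottom of this file): its encoder has one "compress_space_res" and two
+     # "compress_all_res" blocks, so spatial_downscale_factor = 2**3 * patch_size(4) = 32,
+     # while one "compress_time_res" plus the same two "compress_all_res" blocks give
+     # temporal_downscale_factor = 2**3 = 8.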
247
+ def to_json_string(self) -> str:
248
+ import json
249
+
250
+ return json.dumps(self.config.__dict__)
251
+
252
+ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
253
+ if any([key.startswith("vae.") for key in state_dict.keys()]):
254
+ state_dict = {
255
+ key.replace("vae.", ""): value
256
+ for key, value in state_dict.items()
257
+ if key.startswith("vae.")
258
+ }
259
+ ckpt_state_dict = {
260
+ key: value
261
+ for key, value in state_dict.items()
262
+ if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
263
+ }
264
+
265
+ model_keys = set(name for name, _ in self.named_modules())
266
+
267
+ key_mapping = {
268
+ ".resnets.": ".res_blocks.",
269
+ "downsamplers.0": "downsample",
270
+ "upsamplers.0": "upsample",
271
+ }
272
+ converted_state_dict = {}
273
+ for key, value in ckpt_state_dict.items():
274
+ for k, v in key_mapping.items():
275
+ key = key.replace(k, v)
276
+
277
+ key_prefix = ".".join(key.split(".")[:-1])
278
+ if "norm" in key and key_prefix not in model_keys:
279
+ logger.info(
280
+ f"Removing key {key} from state_dict as it is not present in the model"
281
+ )
282
+ continue
283
+
284
+ converted_state_dict[key] = value
285
+
286
+ super().load_state_dict(converted_state_dict, strict=strict)
287
+
288
+ data_dict = {
289
+ key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value
290
+ for key, value in state_dict.items()
291
+ if key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
292
+ }
293
+ if len(data_dict) > 0:
294
+ self.register_buffer("std_of_means", data_dict["std-of-means"])
295
+ self.register_buffer(
296
+ "mean_of_means",
297
+ data_dict.get(
298
+ "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
299
+ ),
300
+ )
301
+
302
+ def last_layer(self):
303
+ if hasattr(self.decoder, "conv_out"):
304
+ if isinstance(self.decoder.conv_out, nn.Sequential):
305
+ last_layer = self.decoder.conv_out[-1]
306
+ else:
307
+ last_layer = self.decoder.conv_out
308
+ else:
309
+ last_layer = self.decoder.layers[-1]
310
+ return last_layer
311
+
312
+ def set_use_tpu_flash_attention(self):
313
+ for block in self.decoder.up_blocks:
314
+ if isinstance(block, UNetMidBlock3D) and block.attention_blocks:
315
+ for attention_block in block.attention_blocks:
316
+ attention_block.set_use_tpu_flash_attention()
317
+
318
+
319
+ class Encoder(nn.Module):
320
+ r"""
321
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
322
+
323
+ Args:
324
+ dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
325
+ The number of dimensions to use in convolutions.
326
+ in_channels (`int`, *optional*, defaults to 3):
327
+ The number of input channels.
328
+ out_channels (`int`, *optional*, defaults to 3):
329
+ The number of output channels.
330
+ blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
331
+ The blocks to use. Each block is a tuple of the block name and the number of layers.
332
+ base_channels (`int`, *optional*, defaults to 128):
333
+ The number of output channels for the first convolutional layer.
334
+ norm_num_groups (`int`, *optional*, defaults to 32):
335
+ The number of groups for normalization.
336
+ patch_size (`int`, *optional*, defaults to 1):
337
+ The patch size to use. Should be a power of 2.
338
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
339
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
340
+ latent_log_var (`str`, *optional*, defaults to `per_channel`):
341
+ How the log variance is parameterized. Can be either `per_channel`, `uniform`, `constant` or `none`.
342
+ """
343
+
344
+ def __init__(
345
+ self,
346
+ dims: Union[int, Tuple[int, int]] = 3,
347
+ in_channels: int = 3,
348
+ out_channels: int = 3,
349
+ blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
350
+ base_channels: int = 128,
351
+ norm_num_groups: int = 32,
352
+ patch_size: Union[int, Tuple[int]] = 1,
353
+ norm_layer: str = "group_norm", # group_norm, pixel_norm
354
+ latent_log_var: str = "per_channel",
355
+ spatial_padding_mode: str = "zeros",
356
+ ):
357
+ super().__init__()
358
+ self.patch_size = patch_size
359
+ self.norm_layer = norm_layer
360
+ self.latent_channels = out_channels
361
+ self.latent_log_var = latent_log_var
362
+ self.blocks_desc = blocks
363
+
364
+ in_channels = in_channels * patch_size**2
365
+ output_channel = base_channels
366
+
367
+ self.conv_in = make_conv_nd(
368
+ dims=dims,
369
+ in_channels=in_channels,
370
+ out_channels=output_channel,
371
+ kernel_size=3,
372
+ stride=1,
373
+ padding=1,
374
+ causal=True,
375
+ spatial_padding_mode=spatial_padding_mode,
376
+ )
377
+
378
+ self.down_blocks = nn.ModuleList([])
379
+
380
+ for block_name, block_params in blocks:
381
+ input_channel = output_channel
382
+ if isinstance(block_params, int):
383
+ block_params = {"num_layers": block_params}
384
+
385
+ if block_name == "res_x":
386
+ block = UNetMidBlock3D(
387
+ dims=dims,
388
+ in_channels=input_channel,
389
+ num_layers=block_params["num_layers"],
390
+ resnet_eps=1e-6,
391
+ resnet_groups=norm_num_groups,
392
+ norm_layer=norm_layer,
393
+ spatial_padding_mode=spatial_padding_mode,
394
+ )
395
+ elif block_name == "res_x_y":
396
+ output_channel = block_params.get("multiplier", 2) * output_channel
397
+ block = ResnetBlock3D(
398
+ dims=dims,
399
+ in_channels=input_channel,
400
+ out_channels=output_channel,
401
+ eps=1e-6,
402
+ groups=norm_num_groups,
403
+ norm_layer=norm_layer,
404
+ spatial_padding_mode=spatial_padding_mode,
405
+ )
406
+ elif block_name == "compress_time":
407
+ block = make_conv_nd(
408
+ dims=dims,
409
+ in_channels=input_channel,
410
+ out_channels=output_channel,
411
+ kernel_size=3,
412
+ stride=(2, 1, 1),
413
+ causal=True,
414
+ spatial_padding_mode=spatial_padding_mode,
415
+ )
416
+ elif block_name == "compress_space":
417
+ block = make_conv_nd(
418
+ dims=dims,
419
+ in_channels=input_channel,
420
+ out_channels=output_channel,
421
+ kernel_size=3,
422
+ stride=(1, 2, 2),
423
+ causal=True,
424
+ spatial_padding_mode=spatial_padding_mode,
425
+ )
426
+ elif block_name == "compress_all":
427
+ block = make_conv_nd(
428
+ dims=dims,
429
+ in_channels=input_channel,
430
+ out_channels=output_channel,
431
+ kernel_size=3,
432
+ stride=(2, 2, 2),
433
+ causal=True,
434
+ spatial_padding_mode=spatial_padding_mode,
435
+ )
436
+ elif block_name == "compress_all_x_y":
437
+ output_channel = block_params.get("multiplier", 2) * output_channel
438
+ block = make_conv_nd(
439
+ dims=dims,
440
+ in_channels=input_channel,
441
+ out_channels=output_channel,
442
+ kernel_size=3,
443
+ stride=(2, 2, 2),
444
+ causal=True,
445
+ spatial_padding_mode=spatial_padding_mode,
446
+ )
447
+ elif block_name == "compress_all_res":
448
+ output_channel = block_params.get("multiplier", 2) * output_channel
449
+ block = SpaceToDepthDownsample(
450
+ dims=dims,
451
+ in_channels=input_channel,
452
+ out_channels=output_channel,
453
+ stride=(2, 2, 2),
454
+ spatial_padding_mode=spatial_padding_mode,
455
+ )
456
+ elif block_name == "compress_space_res":
457
+ output_channel = block_params.get("multiplier", 2) * output_channel
458
+ block = SpaceToDepthDownsample(
459
+ dims=dims,
460
+ in_channels=input_channel,
461
+ out_channels=output_channel,
462
+ stride=(1, 2, 2),
463
+ spatial_padding_mode=spatial_padding_mode,
464
+ )
465
+ elif block_name == "compress_time_res":
466
+ output_channel = block_params.get("multiplier", 2) * output_channel
467
+ block = SpaceToDepthDownsample(
468
+ dims=dims,
469
+ in_channels=input_channel,
470
+ out_channels=output_channel,
471
+ stride=(2, 1, 1),
472
+ spatial_padding_mode=spatial_padding_mode,
473
+ )
474
+ else:
475
+ raise ValueError(f"unknown block: {block_name}")
476
+
477
+ self.down_blocks.append(block)
478
+
479
+ # out
480
+ if norm_layer == "group_norm":
481
+ self.conv_norm_out = nn.GroupNorm(
482
+ num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
483
+ )
484
+ elif norm_layer == "pixel_norm":
485
+ self.conv_norm_out = PixelNorm()
486
+ elif norm_layer == "layer_norm":
487
+ self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
488
+
489
+ self.conv_act = nn.SiLU()
490
+
491
+ conv_out_channels = out_channels
492
+ if latent_log_var == "per_channel":
493
+ conv_out_channels *= 2
494
+ elif latent_log_var == "uniform":
495
+ conv_out_channels += 1
496
+ elif latent_log_var == "constant":
497
+ conv_out_channels += 1
498
+ elif latent_log_var != "none":
499
+ raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
500
+ self.conv_out = make_conv_nd(
501
+ dims,
502
+ output_channel,
503
+ conv_out_channels,
504
+ 3,
505
+ padding=1,
506
+ causal=True,
507
+ spatial_padding_mode=spatial_padding_mode,
508
+ )
509
+
510
+ self.gradient_checkpointing = False
511
+
512
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
513
+ r"""The forward method of the `Encoder` class."""
514
+
515
+ sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
516
+ sample = self.conv_in(sample)
517
+
518
+ checkpoint_fn = (
519
+ partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
520
+ if self.gradient_checkpointing and self.training
521
+ else lambda x: x
522
+ )
523
+
524
+ for down_block in self.down_blocks:
525
+ sample = checkpoint_fn(down_block)(sample)
526
+
527
+ sample = self.conv_norm_out(sample)
528
+ sample = self.conv_act(sample)
529
+ sample = self.conv_out(sample)
530
+
531
+ if self.latent_log_var == "uniform":
532
+ last_channel = sample[:, -1:, ...]
533
+ num_dims = sample.dim()
534
+
535
+ if num_dims == 4:
536
+ # For shape (B, C, H, W)
537
+ repeated_last_channel = last_channel.repeat(
538
+ 1, sample.shape[1] - 2, 1, 1
539
+ )
540
+ sample = torch.cat([sample, repeated_last_channel], dim=1)
541
+ elif num_dims == 5:
542
+ # For shape (B, C, F, H, W)
543
+ repeated_last_channel = last_channel.repeat(
544
+ 1, sample.shape[1] - 2, 1, 1, 1
545
+ )
546
+ sample = torch.cat([sample, repeated_last_channel], dim=1)
547
+ else:
548
+ raise ValueError(f"Invalid input shape: {sample.shape}")
549
+ elif self.latent_log_var == "constant":
550
+ sample = sample[:, :-1, ...]
551
+ approx_ln_0 = (
552
+ -30
553
+ ) # this is the minimal clamp value in DiagonalGaussianDistribution objects
554
+ sample = torch.cat(
555
+ [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
556
+ dim=1,
557
+ )
558
+
559
+
560
+
561
+ # --- START OF THE ADUC SURGICAL PATCH ---
562
+ # Check an environment variable to turn the overlay on/off
563
+
564
+ if os.getenv("ADUC_DEBUG_OVERLAY", "1") == "1":
565
+ try:
566
+ print(f"[ADUC DEBUG LTX *causal_video_autoencoder.py*]=======")
567
+ print(f"[sample] {sample.shape}")
568
+
569
+ # Convert B,C,F,H,W to F,H,W,C on the CPU
570
+ video_np = (sample.clone().squeeze(0).permute(1, 2, 3, 0) * 127.5 + 127.5).byte().cpu().numpy()
571
+
572
+ try:
573
+ font = ImageFont.truetype("arial.ttf", 24)
574
+ except IOError:
575
+ font = ImageFont.load_default(size=24)
576
+
577
+ processed_frames = []
578
+ for i in range(video_np.shape[0]):
579
+ frame_pil = Image.fromarray(video_np[i])
580
+ draw = ImageDraw.Draw(frame_pil)
581
+
582
+ # Simple text, since the fragment context is not available here
583
+ text = f"F: {i}"
584
+ position = (10, frame_pil.height - 40)
585
+
586
+ # Outline for readability
587
+ draw.text((position[0]-1, position[1]-1), text, font=font, fill="black")
588
+ draw.text((position[0]+1, position[1]-1), text, font=font, fill="black")
589
+ draw.text((position[0]-1, position[1]+1), text, font=font, fill="black")
590
+ draw.text((position[0]+1, position[1]+1), text, font=font, fill="black")
591
+ draw.text(position, text, font=font, fill="white")
592
+
593
+ processed_frames.append(np.array(frame_pil))
594
+
595
+ # Convert back to a B,C,F,H,W tensor on the original device
596
+ processed_np = np.stack(processed_frames)
597
+ final_tensor = torch.from_numpy(processed_np).to(sample.device, dtype=torch.float32)
598
+ final_tensor = (final_tensor / 127.5) - 1.0
599
+ sample = final_tensor.permute(3, 0, 1, 2).unsqueeze(0) # F,H,W,C -> B,C,F,H,W
600
+ except Exception as e:
601
+ # If anything goes wrong in the patch, just log the error and continue with the original tensor
602
+ print(f"[ADUC_DEBUG_OVERLAY] Erro ao adicionar texto: {e}")
603
+ # --- END OF THE ADUC SURGICAL PATCH ---
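+         # Note: the overlay above runs by default because of the "1" fallback in os.getenv;
+         # exporting ADUC_DEBUG_OVERLAY=0 disables it, and any failure inside the try block
+         # is swallowed so the original tensor is returned unchanged.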
604
+
605
+ return sample
606
+
607
+
608
+ class Decoder(nn.Module):
609
+ r"""
610
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
611
+
612
+ Args:
613
+ dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
614
+ The number of dimensions to use in convolutions.
615
+ in_channels (`int`, *optional*, defaults to 3):
616
+ The number of input channels.
617
+ out_channels (`int`, *optional*, defaults to 3):
618
+ The number of output channels.
619
+ blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
620
+ The blocks to use. Each block is a tuple of the block name and the number of layers.
621
+ base_channels (`int`, *optional*, defaults to 128):
622
+ The number of output channels for the first convolutional layer.
623
+ norm_num_groups (`int`, *optional*, defaults to 32):
624
+ The number of groups for normalization.
625
+ patch_size (`int`, *optional*, defaults to 1):
626
+ The patch size to use. Should be a power of 2.
627
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
628
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
629
+ causal (`bool`, *optional*, defaults to `True`):
630
+ Whether to use causal convolutions or not.
631
+ """
632
+
633
+ def __init__(
634
+ self,
635
+ dims,
636
+ in_channels: int = 3,
637
+ out_channels: int = 3,
638
+ blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
639
+ base_channels: int = 128,
640
+ layers_per_block: int = 2,
641
+ norm_num_groups: int = 32,
642
+ patch_size: int = 1,
643
+ norm_layer: str = "group_norm",
644
+ causal: bool = True,
645
+ timestep_conditioning: bool = False,
646
+ spatial_padding_mode: str = "zeros",
647
+ ):
648
+ super().__init__()
649
+ self.patch_size = patch_size
650
+ self.layers_per_block = layers_per_block
651
+ out_channels = out_channels * patch_size**2
652
+ self.causal = causal
653
+ self.blocks_desc = blocks
654
+
655
+ # Compute output channel to be product of all channel-multiplier blocks
656
+ output_channel = base_channels
657
+ for block_name, block_params in list(reversed(blocks)):
658
+ block_params = block_params if isinstance(block_params, dict) else {}
659
+ if block_name == "res_x_y":
660
+ output_channel = output_channel * block_params.get("multiplier", 2)
661
+ if block_name.startswith("compress"):
662
+ output_channel = output_channel * block_params.get("multiplier", 1)
663
+
664
+ self.conv_in = make_conv_nd(
665
+ dims,
666
+ in_channels,
667
+ output_channel,
668
+ kernel_size=3,
669
+ stride=1,
670
+ padding=1,
671
+ causal=True,
672
+ spatial_padding_mode=spatial_padding_mode,
673
+ )
674
+
675
+ self.up_blocks = nn.ModuleList([])
676
+
677
+ for block_name, block_params in list(reversed(blocks)):
678
+ input_channel = output_channel
679
+ if isinstance(block_params, int):
680
+ block_params = {"num_layers": block_params}
681
+
682
+ if block_name == "res_x":
683
+ block = UNetMidBlock3D(
684
+ dims=dims,
685
+ in_channels=input_channel,
686
+ num_layers=block_params["num_layers"],
687
+ resnet_eps=1e-6,
688
+ resnet_groups=norm_num_groups,
689
+ norm_layer=norm_layer,
690
+ inject_noise=block_params.get("inject_noise", False),
691
+ timestep_conditioning=timestep_conditioning,
692
+ spatial_padding_mode=spatial_padding_mode,
693
+ )
694
+ elif block_name == "attn_res_x":
695
+ block = UNetMidBlock3D(
696
+ dims=dims,
697
+ in_channels=input_channel,
698
+ num_layers=block_params["num_layers"],
699
+ resnet_groups=norm_num_groups,
700
+ norm_layer=norm_layer,
701
+ inject_noise=block_params.get("inject_noise", False),
702
+ timestep_conditioning=timestep_conditioning,
703
+ attention_head_dim=block_params["attention_head_dim"],
704
+ spatial_padding_mode=spatial_padding_mode,
705
+ )
706
+ elif block_name == "res_x_y":
707
+ output_channel = output_channel // block_params.get("multiplier", 2)
708
+ block = ResnetBlock3D(
709
+ dims=dims,
710
+ in_channels=input_channel,
711
+ out_channels=output_channel,
712
+ eps=1e-6,
713
+ groups=norm_num_groups,
714
+ norm_layer=norm_layer,
715
+ inject_noise=block_params.get("inject_noise", False),
716
+ timestep_conditioning=False,
717
+ spatial_padding_mode=spatial_padding_mode,
718
+ )
719
+ elif block_name == "compress_time":
720
+ block = DepthToSpaceUpsample(
721
+ dims=dims,
722
+ in_channels=input_channel,
723
+ stride=(2, 1, 1),
724
+ spatial_padding_mode=spatial_padding_mode,
725
+ )
726
+ elif block_name == "compress_space":
727
+ block = DepthToSpaceUpsample(
728
+ dims=dims,
729
+ in_channels=input_channel,
730
+ stride=(1, 2, 2),
731
+ spatial_padding_mode=spatial_padding_mode,
732
+ )
733
+ elif block_name == "compress_all":
734
+ output_channel = output_channel // block_params.get("multiplier", 1)
735
+ block = DepthToSpaceUpsample(
736
+ dims=dims,
737
+ in_channels=input_channel,
738
+ stride=(2, 2, 2),
739
+ residual=block_params.get("residual", False),
740
+ out_channels_reduction_factor=block_params.get("multiplier", 1),
741
+ spatial_padding_mode=spatial_padding_mode,
742
+ )
743
+ else:
744
+ raise ValueError(f"unknown layer: {block_name}")
745
+
746
+ self.up_blocks.append(block)
747
+
748
+ if norm_layer == "group_norm":
749
+ self.conv_norm_out = nn.GroupNorm(
750
+ num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
751
+ )
752
+ elif norm_layer == "pixel_norm":
753
+ self.conv_norm_out = PixelNorm()
754
+ elif norm_layer == "layer_norm":
755
+ self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
756
+
757
+ self.conv_act = nn.SiLU()
758
+ self.conv_out = make_conv_nd(
759
+ dims,
760
+ output_channel,
761
+ out_channels,
762
+ 3,
763
+ padding=1,
764
+ causal=True,
765
+ spatial_padding_mode=spatial_padding_mode,
766
+ )
767
+
768
+ self.gradient_checkpointing = False
769
+
770
+ self.timestep_conditioning = timestep_conditioning
771
+
772
+ if timestep_conditioning:
773
+ self.timestep_scale_multiplier = nn.Parameter(
774
+ torch.tensor(1000.0, dtype=torch.float32)
775
+ )
776
+ self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
777
+ output_channel * 2, 0
778
+ )
779
+ self.last_scale_shift_table = nn.Parameter(
780
+ torch.randn(2, output_channel) / output_channel**0.5
781
+ )
782
+
783
+ def forward(
784
+ self,
785
+ sample: torch.FloatTensor,
786
+ target_shape,
787
+ timestep: Optional[torch.Tensor] = None,
788
+ ) -> torch.FloatTensor:
789
+ r"""The forward method of the `Decoder` class."""
790
+ assert target_shape is not None, "target_shape must be provided"
791
+ batch_size = sample.shape[0]
792
+
793
+ sample = self.conv_in(sample, causal=self.causal)
794
+
795
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
796
+
797
+ checkpoint_fn = (
798
+ partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
799
+ if self.gradient_checkpointing and self.training
800
+ else lambda x: x
801
+ )
802
+
803
+ sample = sample.to(upscale_dtype)
804
+
805
+ if self.timestep_conditioning:
806
+ assert (
807
+ timestep is not None
808
+ ), "should pass timestep with timestep_conditioning=True"
809
+ scaled_timestep = timestep * self.timestep_scale_multiplier
810
+
811
+ for up_block in self.up_blocks:
812
+ if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
813
+ sample = checkpoint_fn(up_block)(
814
+ sample, causal=self.causal, timestep=scaled_timestep
815
+ )
816
+ else:
817
+ sample = checkpoint_fn(up_block)(sample, causal=self.causal)
818
+
819
+ sample = self.conv_norm_out(sample)
820
+
821
+ if self.timestep_conditioning:
822
+ embedded_timestep = self.last_time_embedder(
823
+ timestep=scaled_timestep.flatten(),
824
+ resolution=None,
825
+ aspect_ratio=None,
826
+ batch_size=sample.shape[0],
827
+ hidden_dtype=sample.dtype,
828
+ )
829
+ embedded_timestep = embedded_timestep.view(
830
+ batch_size, embedded_timestep.shape[-1], 1, 1, 1
831
+ )
832
+ ada_values = self.last_scale_shift_table[
833
+ None, ..., None, None, None
834
+ ] + embedded_timestep.reshape(
835
+ batch_size,
836
+ 2,
837
+ -1,
838
+ embedded_timestep.shape[-3],
839
+ embedded_timestep.shape[-2],
840
+ embedded_timestep.shape[-1],
841
+ )
842
+ shift, scale = ada_values.unbind(dim=1)
843
+ sample = sample * (1 + scale) + shift
844
+
845
+ sample = self.conv_act(sample)
846
+ sample = self.conv_out(sample, causal=self.causal)
847
+
848
+ sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
849
+
850
+ return sample
851
+
852
+
853
+ class UNetMidBlock3D(nn.Module):
854
+ """
855
+ A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
856
+
857
+ Args:
858
+ in_channels (`int`): The number of input channels.
859
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
860
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
861
+ resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
862
+ resnet_groups (`int`, *optional*, defaults to 32):
863
+ The number of groups to use in the group normalization layers of the resnet blocks.
864
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
865
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
866
+ inject_noise (`bool`, *optional*, defaults to `False`):
867
+ Whether to inject noise into the hidden states.
868
+ timestep_conditioning (`bool`, *optional*, defaults to `False`):
869
+ Whether to condition the hidden states on the timestep.
870
+ attention_head_dim (`int`, *optional*, defaults to -1):
871
+ The dimension of the attention head. If -1, no attention is used.
872
+
873
+ Returns:
874
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
875
+ in_channels, frames, height, width)`.
876
+
877
+ """
878
+
879
+ def __init__(
880
+ self,
881
+ dims: Union[int, Tuple[int, int]],
882
+ in_channels: int,
883
+ dropout: float = 0.0,
884
+ num_layers: int = 1,
885
+ resnet_eps: float = 1e-6,
886
+ resnet_groups: int = 32,
887
+ norm_layer: str = "group_norm",
888
+ inject_noise: bool = False,
889
+ timestep_conditioning: bool = False,
890
+ attention_head_dim: int = -1,
891
+ spatial_padding_mode: str = "zeros",
892
+ ):
893
+ super().__init__()
894
+ resnet_groups = (
895
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
896
+ )
897
+ self.timestep_conditioning = timestep_conditioning
898
+
899
+ if timestep_conditioning:
900
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
901
+ in_channels * 4, 0
902
+ )
903
+
904
+ self.res_blocks = nn.ModuleList(
905
+ [
906
+ ResnetBlock3D(
907
+ dims=dims,
908
+ in_channels=in_channels,
909
+ out_channels=in_channels,
910
+ eps=resnet_eps,
911
+ groups=resnet_groups,
912
+ dropout=dropout,
913
+ norm_layer=norm_layer,
914
+ inject_noise=inject_noise,
915
+ timestep_conditioning=timestep_conditioning,
916
+ spatial_padding_mode=spatial_padding_mode,
917
+ )
918
+ for _ in range(num_layers)
919
+ ]
920
+ )
921
+
922
+ self.attention_blocks = None
923
+
924
+ if attention_head_dim > 0:
925
+ if attention_head_dim > in_channels:
926
+ raise ValueError(
927
+ "attention_head_dim must be less than or equal to in_channels"
928
+ )
929
+
930
+ self.attention_blocks = nn.ModuleList(
931
+ [
932
+ Attention(
933
+ query_dim=in_channels,
934
+ heads=in_channels // attention_head_dim,
935
+ dim_head=attention_head_dim,
936
+ bias=True,
937
+ out_bias=True,
938
+ qk_norm="rms_norm",
939
+ residual_connection=True,
940
+ )
941
+ for _ in range(num_layers)
942
+ ]
943
+ )
944
+
945
+ def forward(
946
+ self,
947
+ hidden_states: torch.FloatTensor,
948
+ causal: bool = True,
949
+ timestep: Optional[torch.Tensor] = None,
950
+ ) -> torch.FloatTensor:
951
+ timestep_embed = None
952
+ if self.timestep_conditioning:
953
+ assert (
954
+ timestep is not None
955
+ ), "should pass timestep with timestep_conditioning=True"
956
+ batch_size = hidden_states.shape[0]
957
+ timestep_embed = self.time_embedder(
958
+ timestep=timestep.flatten(),
959
+ resolution=None,
960
+ aspect_ratio=None,
961
+ batch_size=batch_size,
962
+ hidden_dtype=hidden_states.dtype,
963
+ )
964
+ timestep_embed = timestep_embed.view(
965
+ batch_size, timestep_embed.shape[-1], 1, 1, 1
966
+ )
967
+
968
+ if self.attention_blocks:
969
+ for resnet, attention in zip(self.res_blocks, self.attention_blocks):
970
+ hidden_states = resnet(
971
+ hidden_states, causal=causal, timestep=timestep_embed
972
+ )
973
+
974
+ # Reshape the hidden states to be (batch_size, frames * height * width, channel)
975
+ batch_size, channel, frames, height, width = hidden_states.shape
976
+ hidden_states = hidden_states.view(
977
+ batch_size, channel, frames * height * width
978
+ ).transpose(1, 2)
979
+
980
+ if attention.use_tpu_flash_attention:
981
+ # Pad the second dimension to be divisible by block_k_major (block in flash attention)
982
+ seq_len = hidden_states.shape[1]
983
+ block_k_major = 512
984
+ pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
985
+ if pad_len > 0:
986
+ hidden_states = F.pad(
987
+ hidden_states, (0, 0, 0, pad_len), "constant", 0
988
+ )
989
+
990
+ # Create a mask with ones for the original sequence length and zeros for the padded indexes
991
+ mask = torch.ones(
992
+ (hidden_states.shape[0], seq_len),
993
+ device=hidden_states.device,
994
+ dtype=hidden_states.dtype,
995
+ )
996
+ if pad_len > 0:
997
+ mask = F.pad(mask, (0, pad_len), "constant", 0)
998
+
999
+ hidden_states = attention(
1000
+ hidden_states,
1001
+ attention_mask=(
1002
+ None if not attention.use_tpu_flash_attention else mask
1003
+ ),
1004
+ )
1005
+
1006
+ if attention.use_tpu_flash_attention:
1007
+ # Remove the padding
1008
+ if pad_len > 0:
1009
+ hidden_states = hidden_states[:, :-pad_len, :]
1010
+
1011
+ # Reshape the hidden states back to (batch_size, channel, frames, height, width)
1012
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
1013
+ batch_size, channel, frames, height, width
1014
+ )
1015
+ else:
1016
+ for resnet in self.res_blocks:
1017
+ hidden_states = resnet(
1018
+ hidden_states, causal=causal, timestep=timestep_embed
1019
+ )
1020
+
1021
+ return hidden_states
1022
+
1023
+
1024
+ class SpaceToDepthDownsample(nn.Module):
1025
+ def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
1026
+ super().__init__()
1027
+ self.stride = stride
1028
+ self.group_size = in_channels * np.prod(stride) // out_channels
1029
+ self.conv = make_conv_nd(
1030
+ dims=dims,
1031
+ in_channels=in_channels,
1032
+ out_channels=out_channels // np.prod(stride),
1033
+ kernel_size=3,
1034
+ stride=1,
1035
+ causal=True,
1036
+ spatial_padding_mode=spatial_padding_mode,
1037
+ )
1038
+
1039
+ def forward(self, x, causal: bool = True):
1040
+ if self.stride[0] == 2:
1041
+ x = torch.cat(
1042
+ [x[:, :, :1, :, :], x], dim=2
1043
+ ) # duplicate first frames for padding
1044
+
1045
+ # skip connection
1046
+ x_in = rearrange(
1047
+ x,
1048
+ "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
1049
+ p1=self.stride[0],
1050
+ p2=self.stride[1],
1051
+ p3=self.stride[2],
1052
+ )
1053
+ x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
1054
+ x_in = x_in.mean(dim=2)
1055
+
1056
+ # conv
1057
+ x = self.conv(x, causal=causal)
1058
+ x = rearrange(
1059
+ x,
1060
+ "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
1061
+ p1=self.stride[0],
1062
+ p2=self.stride[1],
1063
+ p3=self.stride[2],
1064
+ )
1065
+
1066
+ x = x + x_in
1067
+
1068
+ return x
1069
+
1070
+
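+ # Shape note (derived from SpaceToDepthDownsample above, with hypothetical example sizes):
+ # with stride (2, 2, 2), in_channels=128 and out_channels=256, the skip path averages groups
+ # of group_size = 128 * 8 // 256 = 4 channels, so out_channels must divide
+ # in_channels * prod(stride) for the residual to line up with the conv branch.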
1071
+ class DepthToSpaceUpsample(nn.Module):
1072
+ def __init__(
1073
+ self,
1074
+ dims,
1075
+ in_channels,
1076
+ stride,
1077
+ residual=False,
1078
+ out_channels_reduction_factor=1,
1079
+ spatial_padding_mode="zeros",
1080
+ ):
1081
+ super().__init__()
1082
+ self.stride = stride
1083
+ self.out_channels = (
1084
+ np.prod(stride) * in_channels // out_channels_reduction_factor
1085
+ )
1086
+ self.conv = make_conv_nd(
1087
+ dims=dims,
1088
+ in_channels=in_channels,
1089
+ out_channels=self.out_channels,
1090
+ kernel_size=3,
1091
+ stride=1,
1092
+ causal=True,
1093
+ spatial_padding_mode=spatial_padding_mode,
1094
+ )
1095
+ self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride)
1096
+ self.residual = residual
1097
+ self.out_channels_reduction_factor = out_channels_reduction_factor
1098
+
1099
+ def forward(self, x, causal: bool = True):
1100
+ if self.residual:
1101
+ # Reshape and duplicate the input to match the output shape
1102
+ x_in = self.pixel_shuffle(x)
1103
+ num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor
1104
+ x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
1105
+ if self.stride[0] == 2:
1106
+ x_in = x_in[:, :, 1:, :, :]
1107
+ x = self.conv(x, causal=causal)
1108
+ x = self.pixel_shuffle(x)
1109
+ if self.stride[0] == 2:
1110
+ x = x[:, :, 1:, :, :]
1111
+ if self.residual:
1112
+ x = x + x_in
1113
+ return x
1114
+
1115
+
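+ # Causality note (from DepthToSpaceUpsample above): when the temporal stride is 2, the first
+ # upsampled frame is dropped, mirroring the first-frame duplication in SpaceToDepthDownsample,
+ # so F latent frames expand to 2 * F - 1 output frames per stage.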
1116
+ class LayerNorm(nn.Module):
1117
+ def __init__(self, dim, eps, elementwise_affine=True) -> None:
1118
+ super().__init__()
1119
+ self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
1120
+
1121
+ def forward(self, x):
1122
+ x = rearrange(x, "b c d h w -> b d h w c")
1123
+ x = self.norm(x)
1124
+ x = rearrange(x, "b d h w c -> b c d h w")
1125
+ return x
1126
+
1127
+
1128
+ class ResnetBlock3D(nn.Module):
1129
+ r"""
1130
+ A Resnet block.
1131
+
1132
+ Parameters:
1133
+ in_channels (`int`): The number of channels in the input.
1134
+ out_channels (`int`, *optional*, default to be `None`):
1135
+ The number of output channels for the first conv layer. If None, same as `in_channels`.
1136
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
1137
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
1138
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
1139
+ """
1140
+
1141
+ def __init__(
1142
+ self,
1143
+ dims: Union[int, Tuple[int, int]],
1144
+ in_channels: int,
1145
+ out_channels: Optional[int] = None,
1146
+ dropout: float = 0.0,
1147
+ groups: int = 32,
1148
+ eps: float = 1e-6,
1149
+ norm_layer: str = "group_norm",
1150
+ inject_noise: bool = False,
1151
+ timestep_conditioning: bool = False,
1152
+ spatial_padding_mode: str = "zeros",
1153
+ ):
1154
+ super().__init__()
1155
+ self.in_channels = in_channels
1156
+ out_channels = in_channels if out_channels is None else out_channels
1157
+ self.out_channels = out_channels
1158
+ self.inject_noise = inject_noise
1159
+
1160
+ if norm_layer == "group_norm":
1161
+ self.norm1 = nn.GroupNorm(
1162
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
1163
+ )
1164
+ elif norm_layer == "pixel_norm":
1165
+ self.norm1 = PixelNorm()
1166
+ elif norm_layer == "layer_norm":
1167
+ self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
1168
+
1169
+ self.non_linearity = nn.SiLU()
1170
+
1171
+ self.conv1 = make_conv_nd(
1172
+ dims,
1173
+ in_channels,
1174
+ out_channels,
1175
+ kernel_size=3,
1176
+ stride=1,
1177
+ padding=1,
1178
+ causal=True,
1179
+ spatial_padding_mode=spatial_padding_mode,
1180
+ )
1181
+
1182
+ if inject_noise:
1183
+ self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
1184
+
1185
+ if norm_layer == "group_norm":
1186
+ self.norm2 = nn.GroupNorm(
1187
+ num_groups=groups, num_channels=out_channels, eps=eps, affine=True
1188
+ )
1189
+ elif norm_layer == "pixel_norm":
1190
+ self.norm2 = PixelNorm()
1191
+ elif norm_layer == "layer_norm":
1192
+ self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
1193
+
1194
+ self.dropout = torch.nn.Dropout(dropout)
1195
+
1196
+ self.conv2 = make_conv_nd(
1197
+ dims,
1198
+ out_channels,
1199
+ out_channels,
1200
+ kernel_size=3,
1201
+ stride=1,
1202
+ padding=1,
1203
+ causal=True,
1204
+ spatial_padding_mode=spatial_padding_mode,
1205
+ )
1206
+
1207
+ if inject_noise:
1208
+ self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
1209
+
1210
+ self.conv_shortcut = (
1211
+ make_linear_nd(
1212
+ dims=dims, in_channels=in_channels, out_channels=out_channels
1213
+ )
1214
+ if in_channels != out_channels
1215
+ else nn.Identity()
1216
+ )
1217
+
1218
+ self.norm3 = (
1219
+ LayerNorm(in_channels, eps=eps, elementwise_affine=True)
1220
+ if in_channels != out_channels
1221
+ else nn.Identity()
1222
+ )
1223
+
1224
+ self.timestep_conditioning = timestep_conditioning
1225
+
1226
+ if timestep_conditioning:
1227
+ self.scale_shift_table = nn.Parameter(
1228
+ torch.randn(4, in_channels) / in_channels**0.5
1229
+ )
1230
+
1231
+ def _feed_spatial_noise(
1232
+ self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
1233
+ ) -> torch.FloatTensor:
1234
+ spatial_shape = hidden_states.shape[-2:]
1235
+ device = hidden_states.device
1236
+ dtype = hidden_states.dtype
1237
+
1238
+ # similar to the "explicit noise inputs" method in style-gan
1239
+ spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
1240
+ scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
1241
+ hidden_states = hidden_states + scaled_noise
1242
+
1243
+ return hidden_states
1244
+
1245
+ def forward(
1246
+ self,
1247
+ input_tensor: torch.FloatTensor,
1248
+ causal: bool = True,
1249
+ timestep: Optional[torch.Tensor] = None,
1250
+ ) -> torch.FloatTensor:
1251
+ hidden_states = input_tensor
1252
+ batch_size = hidden_states.shape[0]
1253
+
1254
+ hidden_states = self.norm1(hidden_states)
1255
+ if self.timestep_conditioning:
1256
+ assert (
1257
+ timestep is not None
1258
+ ), "should pass timestep with timestep_conditioning=True"
1259
+ ada_values = self.scale_shift_table[
1260
+ None, ..., None, None, None
1261
+ ] + timestep.reshape(
1262
+ batch_size,
1263
+ 4,
1264
+ -1,
1265
+ timestep.shape[-3],
1266
+ timestep.shape[-2],
1267
+ timestep.shape[-1],
1268
+ )
1269
+ shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
1270
+
1271
+ hidden_states = hidden_states * (1 + scale1) + shift1
1272
+
1273
+ hidden_states = self.non_linearity(hidden_states)
1274
+
1275
+ hidden_states = self.conv1(hidden_states, causal=causal)
1276
+
1277
+ if self.inject_noise:
1278
+ hidden_states = self._feed_spatial_noise(
1279
+ hidden_states, self.per_channel_scale1
1280
+ )
1281
+
1282
+ hidden_states = self.norm2(hidden_states)
1283
+
1284
+ if self.timestep_conditioning:
1285
+ hidden_states = hidden_states * (1 + scale2) + shift2
1286
+
1287
+ hidden_states = self.non_linearity(hidden_states)
1288
+
1289
+ hidden_states = self.dropout(hidden_states)
1290
+
1291
+ hidden_states = self.conv2(hidden_states, causal=causal)
1292
+
1293
+ if self.inject_noise:
1294
+ hidden_states = self._feed_spatial_noise(
1295
+ hidden_states, self.per_channel_scale2
1296
+ )
1297
+
1298
+ input_tensor = self.norm3(input_tensor)
1299
+
1300
+ batch_size = input_tensor.shape[0]
1301
+
1302
+ input_tensor = self.conv_shortcut(input_tensor)
1303
+
1304
+ output_tensor = input_tensor + hidden_states
1305
+
1306
+ return output_tensor
1307
+
1308
+
1309
+ def patchify(x, patch_size_hw, patch_size_t=1):
1310
+ if patch_size_hw == 1 and patch_size_t == 1:
1311
+ return x
1312
+ if x.dim() == 4:
1313
+ x = rearrange(
1314
+ x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
1315
+ )
1316
+ elif x.dim() == 5:
1317
+ x = rearrange(
1318
+ x,
1319
+ "b c (f p) (h q) (w r) -> b (c p r q) f h w",
1320
+ p=patch_size_t,
1321
+ q=patch_size_hw,
1322
+ r=patch_size_hw,
1323
+ )
1324
+ else:
1325
+ raise ValueError(f"Invalid input shape: {x.shape}")
1326
+
1327
+ return x
1328
+
1329
+
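+ # Shape sketch for patchify/unpatchify (assuming a 5D video tensor): with patch_size_hw=4 and
+ # patch_size_t=1, (B, 3, F, 64, 64) becomes (B, 48, F, 16, 16); unpatchify below inverts the
+ # rearrangement, and the round trip is checked by test_vae_patchify_unpatchify().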
1330
+ def unpatchify(x, patch_size_hw, patch_size_t=1):
1331
+ if patch_size_hw == 1 and patch_size_t == 1:
1332
+ return x
1333
+
1334
+ if x.dim() == 4:
1335
+ x = rearrange(
1336
+ x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
1337
+ )
1338
+ elif x.dim() == 5:
1339
+ x = rearrange(
1340
+ x,
1341
+ "b (c p r q) f h w -> b c (f p) (h q) (w r)",
1342
+ p=patch_size_t,
1343
+ q=patch_size_hw,
1344
+ r=patch_size_hw,
1345
+ )
1346
+
1347
+ return x
1348
+
1349
+
1350
+ def create_video_autoencoder_demo_config(
1351
+ latent_channels: int = 64,
1352
+ ):
1353
+ encoder_blocks = [
1354
+ ("res_x", {"num_layers": 2}),
1355
+ ("compress_space_res", {"multiplier": 2}),
1356
+ ("compress_time_res", {"multiplier": 2}),
1357
+ ("compress_all_res", {"multiplier": 2}),
1358
+ ("compress_all_res", {"multiplier": 2}),
1359
+ ("res_x", {"num_layers": 1}),
1360
+ ]
1361
+ decoder_blocks = [
1362
+ ("res_x", {"num_layers": 2, "inject_noise": False}),
1363
+ ("compress_all", {"residual": True, "multiplier": 2}),
1364
+ ("compress_all", {"residual": True, "multiplier": 2}),
1365
+ ("compress_all", {"residual": True, "multiplier": 2}),
1366
+ ("res_x", {"num_layers": 2, "inject_noise": False}),
1367
+ ]
1368
+ return {
1369
+ "_class_name": "CausalVideoAutoencoder",
1370
+ "dims": 3,
1371
+ "encoder_blocks": encoder_blocks,
1372
+ "decoder_blocks": decoder_blocks,
1373
+ "latent_channels": latent_channels,
1374
+ "norm_layer": "pixel_norm",
1375
+ "patch_size": 4,
1376
+ "latent_log_var": "uniform",
1377
+ "use_quant_conv": False,
1378
+ "causal_decoder": False,
1379
+ "timestep_conditioning": True,
1380
+ "spatial_padding_mode": "replicate",
1381
+ }
1382
+
1383
+
1384
+ def test_vae_patchify_unpatchify():
1385
+ import torch
1386
+
1387
+ x = torch.randn(2, 3, 8, 64, 64)
1388
+ x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
1389
+ x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
1390
+ assert torch.allclose(x, x_unpatched)
1391
+
1392
+
1393
+ def demo_video_autoencoder_forward_backward():
1394
+ # Configuration for the VideoAutoencoder
1395
+ config = create_video_autoencoder_demo_config()
1396
+
1397
+ # Instantiate the VideoAutoencoder with the specified configuration
1398
+ video_autoencoder = CausalVideoAutoencoder.from_config(config)
1399
+
1400
+ print(video_autoencoder)
1401
+ video_autoencoder.eval()
1402
+ # Print the total number of parameters in the video autoencoder
1403
+ total_params = sum(p.numel() for p in video_autoencoder.parameters())
1404
+ print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")
1405
+
1406
+ # Create a mock input tensor simulating a batch of videos
1407
+ # Shape: (batch_size, channels, depth, height, width)
1408
+ # E.g., 2 videos, each with 3 color channels, 17 frames, and 64x64 pixels per frame
1409
+ input_videos = torch.randn(2, 3, 17, 64, 64)
1410
+
1411
+ # Forward pass: encode and decode the input videos
1412
+ latent = video_autoencoder.encode(input_videos).latent_dist.mode()
1413
+ print(f"input shape={input_videos.shape}")
1414
+ print(f"latent shape={latent.shape}")
1415
+
1416
+ timestep = torch.ones(input_videos.shape[0]) * 0.1
1417
+ reconstructed_videos = video_autoencoder.decode(
1418
+ latent, target_shape=input_videos.shape, timestep=timestep
1419
+ ).sample
1420
+
1421
+ print(f"reconstructed shape={reconstructed_videos.shape}")
1422
+
1423
+ # Validate that single image gets treated the same way as first frame
1424
+ input_image = input_videos[:, :, :1, :, :]
1425
+ image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
1426
+ _ = video_autoencoder.decode(
1427
+ image_latent, target_shape=image_latent.shape, timestep=timestep
1428
+ ).sample
1429
+
1430
+ first_frame_latent = latent[:, :, :1, :, :]
1431
+
1432
+ assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
1433
+ # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6)
1434
+ # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
1435
+ # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all()
1436
+
1437
+ # Calculate the loss (e.g., mean squared error)
1438
+ loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
1439
+
1440
+ # Perform backward pass
1441
+ loss.backward()
1442
+
1443
+ print(f"Demo completed with loss: {loss.item()}")
1444
+
1445
+
1446
+ # Run the demo forward and backward pass when executed as a script
1447
+ if __name__ == "__main__":
1448
+ demo_video_autoencoder_forward_backward()