# FILE: managers/vae_manager.py (Versão Final com vae_decode corrigido) import torch import contextlib import logging import sys from pathlib import Path import os import io LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video") def add_deps_to_path(): """ Adiciona o diretório do repositório LTX ao sys.path para garantir que suas bibliotecas possam ser importadas. """ repo_path = str(LTX_VIDEO_REPO_DIR.resolve()) if repo_path not in sys.path: sys.path.insert(0, repo_path) logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}") # Executa a função imediatamente para configurar o ambiente antes de qualquer importação. add_deps_to_path() # --- IMPORTAÇÃO CRÍTICA --- # Importa a função helper oficial da biblioteca LTX para decodificação. try: from ltx_video.models.autoencoders.vae_encode import vae_decode except ImportError: raise ImportError("Could not import 'vae_decode' from LTX-Video library. Check sys.path and repo integrity.") class _SimpleVAEManager: """ Manages VAE decoding, now using the official 'vae_decode' helper function for maximum compatibility. """ def __init__(self): self.pipeline = None self.device = torch.device("cpu") self.autocast_dtype = torch.float32 def attach_pipeline(self, pipeline, device=None, autocast_dtype=None): self.pipeline = pipeline if device is not None: self.device = torch.device(device) logging.info(f"[VAEManager] VAE device successfully set to: {self.device}") if autocast_dtype is not None: self.autocast_dtype = autocast_dtype @torch.no_grad() def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor: """ Decodes a latent tensor into a pixel tensor using the 'vae_decode' helper. """ if self.pipeline is None: raise RuntimeError("VAEManager: No pipeline has been attached.") # Move os latentes para o dispositivo VAE dedicado. latent_tensor_on_vae_device = latent_tensor.to(self.device) # Prepara o tensor de timesteps no mesmo dispositivo. num_items_in_batch = latent_tensor_on_vae_device.shape[0] timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device) autocast_device_type = self.device.type ctx = torch.autocast( device_type=autocast_device_type, dtype=self.autocast_dtype, enabled=(autocast_device_type == 'cuda') ) with ctx: logging.debug(f"[VAEManager] Decoding latents with shape {latent_tensor_on_vae_device.shape} on {self.device}.") # --- CORREÇÃO PRINCIPAL --- # Usa a função helper `vae_decode` em vez de chamar `vae.decode` diretamente. # Esta função sabe como lidar com o argumento 'timestep'. pixels = vae_decode( latents=latent_tensor_on_vae_device, vae=self.pipeline.vae, is_video=True, timestep=timestep_tensor, vae_per_channel_normalize=True, # Importante manter este parâmetro consistente ) # A função vae_decode já retorna no intervalo [0, 1], mas um clamp extra não faz mal. pixels = pixels.clamp(0, 1) logging.debug("[VAEManager] Decoding complete. Moving pixel tensor to CPU.") return pixels.cpu() # Singleton global vae_manager_singleton = _SimpleVAEManager()