Spaces:
Paused
Paused
| # FILE: managers/vae_manager.py (Versão Final com vae_decode corrigido) | |
| import torch | |
| import contextlib | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| import os | |
| import io | |
| LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video") | |
| def add_deps_to_path(): | |
| """ | |
| Adiciona o diretório do repositório LTX ao sys.path para garantir que suas | |
| bibliotecas possam ser importadas. | |
| """ | |
| repo_path = str(LTX_VIDEO_REPO_DIR.resolve()) | |
| if repo_path not in sys.path: | |
| sys.path.insert(0, repo_path) | |
| logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}") | |
| # Executa a função imediatamente para configurar o ambiente antes de qualquer importação. | |
| add_deps_to_path() | |
| # --- IMPORTAÇÃO CRÍTICA --- | |
| # Importa a função helper oficial da biblioteca LTX para decodificação. | |
| try: | |
| from ltx_video.models.autoencoders.vae_encode import vae_decode | |
| except ImportError: | |
| raise ImportError("Could not import 'vae_decode' from LTX-Video library. Check sys.path and repo integrity.") | |
| class _SimpleVAEManager: | |
| """ | |
| Manages VAE decoding, now using the official 'vae_decode' helper function | |
| for maximum compatibility. | |
| """ | |
| def __init__(self): | |
| self.pipeline = None | |
| self.device = torch.device("cpu") | |
| self.autocast_dtype = torch.float32 | |
| def attach_pipeline(self, pipeline, device=None, autocast_dtype=None): | |
| self.pipeline = pipeline | |
| if device is not None: | |
| self.device = torch.device(device) | |
| logging.info(f"[VAEManager] VAE device successfully set to: {self.device}") | |
| if autocast_dtype is not None: | |
| self.autocast_dtype = autocast_dtype | |
| def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor: | |
| """ | |
| Decodes a latent tensor into a pixel tensor using the 'vae_decode' helper. | |
| """ | |
| if self.pipeline is None: | |
| raise RuntimeError("VAEManager: No pipeline has been attached.") | |
| # Move os latentes para o dispositivo VAE dedicado. | |
| latent_tensor_on_vae_device = latent_tensor.to(self.device) | |
| # Prepara o tensor de timesteps no mesmo dispositivo. | |
| num_items_in_batch = latent_tensor_on_vae_device.shape[0] | |
| timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device) | |
| autocast_device_type = self.device.type | |
| ctx = torch.autocast( | |
| device_type=autocast_device_type, | |
| dtype=self.autocast_dtype, | |
| enabled=(autocast_device_type == 'cuda') | |
| ) | |
| with ctx: | |
| logging.debug(f"[VAEManager] Decoding latents with shape {latent_tensor_on_vae_device.shape} on {self.device}.") | |
| # --- CORREÇÃO PRINCIPAL --- | |
| # Usa a função helper `vae_decode` em vez de chamar `vae.decode` diretamente. | |
| # Esta função sabe como lidar com o argumento 'timestep'. | |
| pixels = vae_decode( | |
| latents=latent_tensor_on_vae_device, | |
| vae=self.pipeline.vae, | |
| is_video=True, | |
| timestep=timestep_tensor, | |
| vae_per_channel_normalize=True, # Importante manter este parâmetro consistente | |
| ) | |
| # A função vae_decode já retorna no intervalo [0, 1], mas um clamp extra não faz mal. | |
| pixels = pixels.clamp(0, 1) | |
| logging.debug("[VAEManager] Decoding complete. Moving pixel tensor to CPU.") | |
| return pixels.cpu() | |
| # Singleton global | |
| vae_manager_singleton = _SimpleVAEManager() |