Test / managers /vae_manager.py
eeuuia's picture
Update managers/vae_manager.py
28859c6 verified
raw
history blame
3.62 kB
# FILE: managers/vae_manager.py (Versão Final com vae_decode corrigido)
import torch
import contextlib
import logging
import sys
from pathlib import Path
import os
import io
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
def add_deps_to_path():
"""
Adiciona o diretório do repositório LTX ao sys.path para garantir que suas
bibliotecas possam ser importadas.
"""
repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
if repo_path not in sys.path:
sys.path.insert(0, repo_path)
logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")
# Executa a função imediatamente para configurar o ambiente antes de qualquer importação.
add_deps_to_path()
# --- IMPORTAÇÃO CRÍTICA ---
# Importa a função helper oficial da biblioteca LTX para decodificação.
try:
from ltx_video.models.autoencoders.vae_encode import vae_decode
except ImportError:
raise ImportError("Could not import 'vae_decode' from LTX-Video library. Check sys.path and repo integrity.")
class _SimpleVAEManager:
"""
Manages VAE decoding, now using the official 'vae_decode' helper function
for maximum compatibility.
"""
def __init__(self):
self.pipeline = None
self.device = torch.device("cpu")
self.autocast_dtype = torch.float32
def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
self.pipeline = pipeline
if device is not None:
self.device = torch.device(device)
logging.info(f"[VAEManager] VAE device successfully set to: {self.device}")
if autocast_dtype is not None:
self.autocast_dtype = autocast_dtype
@torch.no_grad()
def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
"""
Decodes a latent tensor into a pixel tensor using the 'vae_decode' helper.
"""
if self.pipeline is None:
raise RuntimeError("VAEManager: No pipeline has been attached.")
# Move os latentes para o dispositivo VAE dedicado.
latent_tensor_on_vae_device = latent_tensor.to(self.device)
# Prepara o tensor de timesteps no mesmo dispositivo.
num_items_in_batch = latent_tensor_on_vae_device.shape[0]
timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device)
autocast_device_type = self.device.type
ctx = torch.autocast(
device_type=autocast_device_type,
dtype=self.autocast_dtype,
enabled=(autocast_device_type == 'cuda')
)
with ctx:
logging.debug(f"[VAEManager] Decoding latents with shape {latent_tensor_on_vae_device.shape} on {self.device}.")
# --- CORREÇÃO PRINCIPAL ---
# Usa a função helper `vae_decode` em vez de chamar `vae.decode` diretamente.
# Esta função sabe como lidar com o argumento 'timestep'.
pixels = vae_decode(
latents=latent_tensor_on_vae_device,
vae=self.pipeline.vae,
is_video=True,
timestep=timestep_tensor,
vae_per_channel_normalize=True, # Importante manter este parâmetro consistente
)
# A função vae_decode já retorna no intervalo [0, 1], mas um clamp extra não faz mal.
pixels = pixels.clamp(0, 1)
logging.debug("[VAEManager] Decoding complete. Moving pixel tensor to CPU.")
return pixels.cpu()
# Singleton global
vae_manager_singleton = _SimpleVAEManager()