Spaces:
Paused
Paused
File size: 3,617 Bytes
1dc8d3d 8815ceb b8a0748 28859c6 8dfb40e 28859c6 8dfb40e 829e1b9 1dc8d3d 8815ceb b8a0748 1dc8d3d b8a0748 1dc8d3d b8a0748 8815ceb b8a0748 8815ceb 441491f b8a0748 1dc8d3d b8a0748 1dc8d3d b8a0748 1dc8d3d b8a0748 441491f b8a0748 8815ceb b8a0748 1dc8d3d 441491f 1dc8d3d b8a0748 1dc8d3d b8a0748 8815ceb 1dc8d3d b8a0748 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
# FILE: managers/vae_manager.py (final version with corrected vae_decode usage)
import torch
import contextlib
import logging
import sys
from pathlib import Path
import os
import io
# Location of the LTX-Video checkout; add_deps_to_path() resolves this path
# and places it on sys.path.
LTX_VIDEO_REPO_DIR: Path = Path("/data/LTX-Video")
def add_deps_to_path():
    """Put the LTX-Video checkout on ``sys.path`` so its packages can be imported.

    The resolved form of ``LTX_VIDEO_REPO_DIR`` is inserted at the front of
    ``sys.path`` unless that exact string is already listed, so calling this
    function repeatedly never adds a duplicate entry.
    """
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path in sys.path:
        # Already registered: nothing to insert.
        return

    # Index 0 makes this directory the first one searched on import.
    sys.path.insert(0, repo_path)
    # NOTE(review): indentation was lost in the reviewed copy; the log call is
    # assumed to belong to the "inserted" branch, as its message suggests.
    logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")
# Run immediately so sys.path is configured before the import below executes.
add_deps_to_path()

# --- CRITICAL IMPORT ---
# Import the decoding helper shipped with the LTX-Video library.
try:
    from ltx_video.models.autoencoders.vae_encode import vae_decode
except ImportError as exc:
    # Chain explicitly so the underlying failure (repository not found, or one
    # of the library's own imports failing) is reported as the direct cause.
    raise ImportError("Could not import 'vae_decode' from LTX-Video library. Check sys.path and repo integrity.") from exc
class _SimpleVAEManager:
    """Decode latents to pixels through the LTX-Video ``vae_decode`` helper.

    A pipeline exposing a ``vae`` attribute must be registered with
    ``attach_pipeline`` before ``decode`` can be used.
    """

    def __init__(self):
        # Defaults until attach_pipeline() overrides them:
        # no pipeline, CPU device, float32 autocast dtype.
        self.pipeline = None
        self.device = torch.device("cpu")
        self.autocast_dtype = torch.float32

    def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
        """Register the pipeline and optionally override device / autocast dtype.

        Args:
            pipeline: Object whose ``vae`` attribute is passed to ``vae_decode``.
            device: Anything accepted by ``torch.device``. ``None`` keeps the
                currently configured device.
            autocast_dtype: dtype forwarded to ``torch.autocast`` during
                decoding. ``None`` keeps the currently configured dtype.
        """
        self.pipeline = pipeline

        if device is not None:
            self.device = torch.device(device)
            logging.info(f"[VAEManager] VAE device successfully set to: {self.device}")

        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype

    @torch.no_grad()
    def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """Decode a latent tensor into a pixel tensor via ``vae_decode``.

        Args:
            latent_tensor: Latents to decode; the first dimension is used as
                the batch size.
            decode_timestep: Value repeated once per batch item to build the
                ``timestep`` tensor given to ``vae_decode``.

        Returns:
            The decoded tensor, clamped to ``[0, 1]`` and moved to the CPU.

        Raises:
            RuntimeError: If no pipeline has been attached yet.
        """
        if self.pipeline is None:
            raise RuntimeError("VAEManager: No pipeline has been attached.")

        # Latents and the timestep tensor both live on the configured VAE device.
        vae_device = self.device
        latents = latent_tensor.to(vae_device)
        batch_size = latents.shape[0]
        timesteps = torch.tensor([decode_timestep] * batch_size, device=vae_device)

        # The autocast context is always entered, but only enabled on CUDA.
        device_type = vae_device.type
        with torch.autocast(
            device_type=device_type,
            dtype=self.autocast_dtype,
            enabled=(device_type == 'cuda'),
        ):
            logging.debug(f"[VAEManager] Decoding latents with shape {latents.shape} on {self.device}.")
            # Go through the `vae_decode` helper rather than calling
            # `vae.decode` directly, because the helper accepts `timestep`.
            decoded = vae_decode(
                latents=latents,
                vae=self.pipeline.vae,
                is_video=True,
                timestep=timesteps,
                vae_per_channel_normalize=True,  # keep consistent with the rest of the project
            )

        # NOTE(review): this assumes vae_decode already returns values in
        # [0, 1]. If it returns [-1, 1], this clamp discards the negative
        # half of the range - confirm against the LTX-Video library.
        clamped = decoded.clamp(0, 1)
        logging.debug("[VAEManager] Decoding complete. Moving pixel tensor to CPU.")
        return clamped.cpu()
# Module-level singleton instance of the VAE manager.
vae_manager_singleton = _SimpleVAEManager()