File size: 3,617 Bytes
1dc8d3d
8815ceb
 
 
b8a0748
28859c6
8dfb40e
28859c6
 
8dfb40e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
829e1b9
1dc8d3d
 
 
 
 
 
 
 
8815ceb
b8a0748
1dc8d3d
 
b8a0748
 
 
1dc8d3d
b8a0748
8815ceb
 
 
 
b8a0748
 
8815ceb
 
 
441491f
 
b8a0748
1dc8d3d
b8a0748
 
1dc8d3d
 
 
b8a0748
 
1dc8d3d
b8a0748
 
441491f
b8a0748
 
 
 
 
 
 
8815ceb
b8a0748
1dc8d3d
 
 
 
 
 
 
 
 
 
 
441491f
1dc8d3d
 
b8a0748
1dc8d3d
b8a0748
8815ceb
1dc8d3d
b8a0748
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# FILE: managers/vae_manager.py (final version with corrected vae_decode usage)

import torch
import contextlib
import logging
import sys
from pathlib import Path
import os
import io

LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")

def add_deps_to_path(repo_dir=None):
    """
    Ensure the LTX-Video repository is importable.

    Inserts the resolved repository path at the front of ``sys.path`` so the
    repo's packages take precedence over any same-named installed packages.
    Idempotent: does nothing when the path is already present.

    Args:
        repo_dir: Optional ``Path`` to the repository directory. Defaults to
            the module-level ``LTX_VIDEO_REPO_DIR`` (backward compatible with
            the original zero-argument call).
    """
    target = LTX_VIDEO_REPO_DIR if repo_dir is None else repo_dir
    repo_path = str(target.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        # Log tag fixed: this module is vae_manager, not ltx_utils
        # (the old "[ltx_utils]" prefix was a copy-paste leftover).
        logging.info(f"[vae_manager] LTX-Video repository added to sys.path: {repo_path}")

# Run immediately so the environment is configured before any LTX import below.
add_deps_to_path()


# --- CRITICAL IMPORT ---
# Import the official LTX library helper used for decoding.
try:
    from ltx_video.models.autoencoders.vae_encode import vae_decode
except ImportError:
    raise ImportError("Could not import 'vae_decode' from LTX-Video library. Check sys.path and repo integrity.")


class _SimpleVAEManager:
    """
    Manages VAE decoding, now using the official 'vae_decode' helper function
    for maximum compatibility.
    """
    def __init__(self):
        self.pipeline = None
        self.device = torch.device("cpu")
        self.autocast_dtype = torch.float32

    def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
        self.pipeline = pipeline
        if device is not None:
            self.device = torch.device(device)
            logging.info(f"[VAEManager] VAE device successfully set to: {self.device}")
        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype

    @torch.no_grad()
    def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """
        Decodes a latent tensor into a pixel tensor using the 'vae_decode' helper.
        """
        if self.pipeline is None:
            raise RuntimeError("VAEManager: No pipeline has been attached.")
        
        # Move os latentes para o dispositivo VAE dedicado.
        latent_tensor_on_vae_device = latent_tensor.to(self.device)

        # Prepara o tensor de timesteps no mesmo dispositivo.
        num_items_in_batch = latent_tensor_on_vae_device.shape[0]
        timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device)
    
        autocast_device_type = self.device.type
        ctx = torch.autocast(
            device_type=autocast_device_type,
            dtype=self.autocast_dtype,
            enabled=(autocast_device_type == 'cuda')
        )
        
        with ctx:
            logging.debug(f"[VAEManager] Decoding latents with shape {latent_tensor_on_vae_device.shape} on {self.device}.")
            
            # --- CORREÇÃO PRINCIPAL ---
            # Usa a função helper `vae_decode` em vez de chamar `vae.decode` diretamente.
            # Esta função sabe como lidar com o argumento 'timestep'.
            pixels = vae_decode(
                latents=latent_tensor_on_vae_device,
                vae=self.pipeline.vae,
                is_video=True,
                timestep=timestep_tensor,
                vae_per_channel_normalize=True, # Importante manter este parâmetro consistente
            )
    
        # A função vae_decode já retorna no intervalo [0, 1], mas um clamp extra não faz mal.
        pixels = pixels.clamp(0, 1)
        
        logging.debug("[VAEManager] Decoding complete. Moving pixel tensor to CPU.")
        return pixels.cpu()

# Module-level singleton shared by all importers of this module.
vae_manager_singleton = _SimpleVAEManager()