# Source: Test / managers/vae_manager.py (Hugging Face file view)
# Last commit b8a0748 by eeuuia: "Update managers/vae_manager.py" (verified), 4.26 kB
# FILE: managers/vae_manager.py
# DESCRIPTION: Singleton manager for VAE decoding operations, supporting dedicated GPU devices.
import torch
import contextlib
import logging
class _SimpleVAEManager:
    """
    Manages VAE decoding. It's designed to be aware that the VAE might reside
    on a different GPU than the main generation pipeline (e.g., Transformer).

    The manager never moves the VAE itself: ``pipeline.vae`` is expected to
    already live on the device passed to :meth:`attach_pipeline`. Only the
    input latents and the auxiliary timestep tensor are moved by
    :meth:`decode`.
    """

    def __init__(self):
        """Initializes the manager without a pipeline attached."""
        self.pipeline = None
        self.device = torch.device("cpu")  # Defaults to CPU until a device is attached.
        self.autocast_dtype = torch.float32

    def __repr__(self):
        return (
            f"{type(self).__name__}(device={self.device}, "
            f"autocast_dtype={self.autocast_dtype}, "
            f"pipeline_attached={self.pipeline is not None})"
        )

    def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
        """
        Attaches the main pipeline and stores the device (and optionally the
        autocast precision) that VAE decoding should use.

        Args:
            pipeline: The main LTX video pipeline instance. ``decode`` reads
                its ``vae`` attribute.
            device (torch.device or str, optional): Target device for VAE
                operations (e.g. ``'cuda:1'``). When omitted, the previously
                stored device is kept.
            autocast_dtype (torch.dtype, optional): Precision for
                ``torch.autocast`` (only applied on CUDA devices). When
                omitted, the previously stored dtype is kept.
        """
        self.pipeline = pipeline
        if device is not None:
            self.device = torch.device(device)
            logging.info("[VAEManager] VAE device successfully set to: %s", self.device)
        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype

    def _autocast_context(self):
        """Return the autocast context manager to use for ``self.device``.

        ``torch.autocast`` raises ``RuntimeError`` for device types it does not
        support (e.g. ``'meta'``, or ``'mps'`` on older PyTorch) even when
        ``enabled=False``, so non-CUDA devices get a no-op context instead.
        Note: the no-op context leaves any enclosing autocast region untouched.
        """
        if self.device.type == "cuda":
            return torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
        return contextlib.nullcontext()

    @torch.no_grad()
    def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """
        Decodes a latent tensor into a pixel tensor on the configured VAE device.

        Args:
            latent_tensor (torch.Tensor): The latents to decode, on any device.
                The first dimension is treated as the batch dimension.
            decode_timestep (float): Timestep passed to the VAE decoder, one
                copy per batch item.

        Returns:
            torch.Tensor: Pixels clamped and rescaled to ``[0, 1]``, on the CPU.

        Raises:
            RuntimeError: If no pipeline has been attached.
            AttributeError: If the attached pipeline has no usable ``vae``.
        """
        if self.pipeline is None:
            raise RuntimeError("VAEManager: No pipeline has been attached. Call attach_pipeline() first.")
        if not hasattr(self.pipeline, 'vae'):
            raise AttributeError("VAEManager: The attached pipeline does not have a 'vae' attribute.")
        vae = self.pipeline.vae
        if vae is None:
            # Fail early and clearly instead of with "'NoneType' has no attribute 'config'".
            raise AttributeError("VAEManager: The attached pipeline's 'vae' attribute is None.")

        # Move the input latents to the dedicated VAE device; the VAE itself is
        # assumed to be there already (this manager never moves it).
        logging.debug(
            "[VAEManager] Moving latents from %s to VAE device %s for decoding.",
            latent_tensor.device,
            self.device,
        )
        latents = latent_tensor.to(self.device)

        # One decode timestep per batch item, created directly on the VAE device.
        timestep_tensor = torch.tensor([decode_timestep] * latents.shape[0], device=self.device)

        with self._autocast_context():
            logging.debug("[VAEManager] Decoding latents with shape %s on %s.", latents.shape, self.device)
            # Divide by the VAE's configured scaling factor before decoding
            # (presumably the inverse of the scaling applied at encode time).
            scaled_latents = latents / vae.config.scaling_factor
            pixels = vae.decode(scaled_latents, timesteps=timestep_tensor).sample

        # Map decoder output from [-1, 1] to [0, 1]; out-of-range values are clamped.
        pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0

        # CPU is a safe default for downstream consumers such as video saving / UI display.
        logging.debug("[VAEManager] Decoding complete. Moving pixel tensor to CPU.")
        return pixels.cpu()
# Shared, module-level manager instance intended for application-wide use;
# attach the pipeline to it once rather than creating separate managers.
vae_manager_singleton: _SimpleVAEManager = _SimpleVAEManager()