# FILE: managers/vae_manager.py
# DESCRIPTION: Singleton manager for VAE decoding operations, supporting dedicated GPU devices.

import torch
import contextlib
import logging


class _SimpleVAEManager:
    """
    Manages VAE decoding. It is designed to be aware that the VAE might reside
    on a different GPU than the main generation pipeline (e.g., the Transformer).
    """

    def __init__(self):
        """Initializes the manager without a pipeline attached."""
        self.pipeline = None
        self.device = torch.device("cpu")  # Defaults to CPU until a device is attached.
        self.autocast_dtype = torch.float32

    def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
        """
        Attaches the main pipeline and, crucially, stores the specific device
        that this manager and its associated VAE should operate on.

        Args:
            pipeline: The main LTX video pipeline instance.
            device (torch.device or str): The target device for VAE operations (e.g., 'cuda:1').
            autocast_dtype (torch.dtype): The precision for torch.autocast.
        """
        self.pipeline = pipeline
        if device is not None:
            self.device = torch.device(device)
            logging.info(f"[VAEManager] VAE device successfully set to: {self.device}")
        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype

    def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """
        Decodes a latent tensor into a pixel tensor.

        This method ensures that the decoding operation happens on the correct,
        potentially dedicated, VAE device.

        Args:
            latent_tensor (torch.Tensor): The latents to decode, typically on the main device or CPU.
            decode_timestep (float): The timestep for VAE decoding.

        Returns:
            torch.Tensor: The resulting pixel tensor, moved to the CPU for general use.
        """
        if self.pipeline is None:
            raise RuntimeError("VAEManager: No pipeline has been attached. Call attach_pipeline() first.")
        if not hasattr(self.pipeline, 'vae'):
            raise AttributeError("VAEManager: The attached pipeline does not have a 'vae' attribute.")

        # 1. Move the input latents to the dedicated VAE device. This is the critical step.
        logging.debug(f"[VAEManager] Moving latents from {latent_tensor.device} to VAE device {self.device} for decoding.")
        latent_tensor_on_vae_device = latent_tensor.to(self.device)

        # 2. Get a reference to the VAE model (which is already on the correct device).
        vae = self.pipeline.vae

        # 3. Prepare other necessary tensors on the same VAE device.
        num_items_in_batch = latent_tensor_on_vae_device.shape[0]
        timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device)

        # 4. Set up the autocast context for the target device type.
        autocast_device_type = self.device.type
        ctx = torch.autocast(
            device_type=autocast_device_type,
            dtype=self.autocast_dtype,
            enabled=(autocast_device_type == 'cuda'),
        )

        # 5. Perform the decoding operation within the autocast context.
        with ctx:
            logging.debug(f"[VAEManager] Decoding latents with shape {latent_tensor_on_vae_device.shape} on {self.device}.")
            # The VAE expects latents scaled by its scaling factor.
            scaled_latents = latent_tensor_on_vae_device / vae.config.scaling_factor
            pixels = vae.decode(scaled_latents, timesteps=timestep_tensor).sample

        # 6. Post-process the output: normalize to the [0, 1] range.
        pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0

        # 7. Move the final pixel tensor to the CPU. This is a safe default, as subsequent
        #    operations like video saving or UI display typically expect CPU tensors.
        logging.debug("[VAEManager] Decoding complete. Moving pixel tensor to CPU.")
        return pixels.cpu()


# Create a single, global instance of the manager to be used throughout the application.
vae_manager_singleton = _SimpleVAEManager()
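
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module). It assumes the
# caller has already built an LTX pipeline and placed its VAE on a dedicated
# GPU; `load_ltx_pipeline` and `latents` are hypothetical placeholders.
#
#   from managers.vae_manager import vae_manager_singleton
#
#   pipeline = load_ltx_pipeline()          # hypothetical pipeline loader
#   pipeline.vae.to("cuda:1")               # keep the VAE on its own GPU
#   vae_manager_singleton.attach_pipeline(
#       pipeline,
#       device="cuda:1",                    # device the VAE actually lives on
#       autocast_dtype=torch.bfloat16,
#   )
#   pixels = vae_manager_singleton.decode(latents, decode_timestep=0.05)
#   # `pixels` is a CPU tensor normalized to [0, 1], ready for saving or display.
# ---------------------------------------------------------------------------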