Spaces:
Paused
Paused
# vae_manager.py — simple version (beta 1.0)
# Decodes latents (B,C,T,H,W) → pixels (B,C,T,H',W') in [0,1].
| import torch | |
| import contextlib | |
class _SimpleVAEManager:
    """Decode 5D video latents into pixels through an attached pipeline.

    The manager holds a reference to a pipeline object, the device on which
    decoding should run, and the dtype used for CUDA autocast.
    """

    # Reduced-precision dtypes for which opening a CUDA autocast region makes
    # sense; any other dtype (e.g. float32) means "decode in full precision".
    _AUTOCAST_DTYPES = (torch.float16, torch.bfloat16)

    def __init__(self, pipeline=None, device=None, autocast_dtype=torch.float32):
        """
        pipeline: LTX object exposing decode_latents(...) or .vae.decode(...)
        device: device where decoding runs ("cuda", "cuda:0", "cpu" or a
            torch.device); defaults to "cuda" when available, else "cpu"
        autocast_dtype: autocast dtype used on CUDA (bf16/fp16); fp32 disables
            autocast
        """
        self.pipeline = pipeline
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.autocast_dtype = autocast_dtype

    def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
        """Attach (or replace) the pipeline; optionally override device/dtype.

        `device` and `autocast_dtype` are only updated when not None, so the
        previously configured values are kept otherwise.
        """
        self.pipeline = pipeline
        if device is not None:
            self.device = device
        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype

    def _autocast_context(self):
        """Return the context manager under which decoding should run."""
        # Compare the device *type* so that "cuda:0" and torch.device objects
        # are recognised as CUDA too, not only the exact string "cuda".
        on_cuda = torch.device(self.device).type == "cuda"
        # Requesting CUDA autocast with an unsupported dtype such as float32
        # only yields a warning and a disabled region in recent PyTorch
        # releases, so skip it altogether in that case.
        if on_cuda and self.autocast_dtype in self._AUTOCAST_DTYPES:
            return torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
        return contextlib.nullcontext()

    @staticmethod
    def _unwrap_decoded(decoded):
        """Extract the pixel tensor from whatever the decoder returned."""
        if isinstance(decoded, torch.Tensor):
            return decoded
        if isinstance(decoded, (tuple, list)) and decoded:
            # Tuple-style return (e.g. decode(..., return_dict=False)).
            decoded = decoded[0]
        else:
            # Output-object style return exposing the tensor as `.sample`.
            decoded = getattr(decoded, "sample", decoded)
        if not isinstance(decoded, torch.Tensor):
            raise TypeError(
                f"Saída do decode não é um tensor: {type(decoded).__name__}"
            )
        return decoded

    def decode(self, latents_5d: torch.Tensor) -> torch.Tensor:
        """
        Decode the whole 5D block at once, mirroring the simple flow used by
        deformes4D.

        Returns a 5D pixel tensor in [0,1] with shape (B,C,T,H',W').

        Raises:
            RuntimeError: no pipeline attached, or the pipeline exposes
                neither decode_latents nor vae.decode.
            TypeError: the decoder result contains no tensor.
        """
        if self.pipeline is None:
            raise RuntimeError("VAE Manager sem pipeline. Chame attach_pipeline primeiro.")

        # Make sure the latents live on the decoding device.
        latents_5d = latents_5d.to(self.device, non_blocking=True)

        with self._autocast_context():
            if hasattr(self.pipeline, "decode_latents"):
                decoded = self.pipeline.decode_latents(latents_5d)
            elif hasattr(self.pipeline, "vae") and hasattr(self.pipeline.vae, "decode"):
                decoded = self.pipeline.vae.decode(latents_5d)
            else:
                raise RuntimeError("Pipeline não expõe decode_latents nem vae.decode.")

        pixels_5d = self._unwrap_decoded(decoded)

        # Map to [0,1]. NOTE(review): the range is guessed from the data — any
        # negative value is taken to mean the output is in [-1,1]; a [-1,1]
        # output with no negative values would be treated as already in [0,1].
        if pixels_5d.min() < 0:
            pixels_5d = (pixels_5d.clamp(-1, 1) + 1.0) / 2.0
        else:
            pixels_5d = pixels_5d.clamp(0, 1)
        return pixels_5d


# Global singleton for simple use
vae_manager_singleton = _SimpleVAEManager()