Update managers/vae_manager.py
managers/vae_manager.py  +23 -22  CHANGED
@@ -22,34 +22,35 @@ class _SimpleVAEManager:
         if autocast_dtype is not None:
             self.autocast_dtype = autocast_dtype
 
-    @torch.no_grad()
-    def decode(self, latents_5d: torch.Tensor) -> torch.Tensor:
-        """
-        Decode the whole 5D block in one pass, replicating the simple deformes4D flow.
-        Returns a 5D pixel tensor in [0, 1] with shape (B, C, T, H', W').
-        """
-        if self.pipeline is None:
-            raise RuntimeError("VAE Manager has no pipeline. Call attach_pipeline first.")
 
-        # Ensure the correct device
-        latents_5d = latents_5d.to(self.device, non_blocking=True)
 
+    @torch.no_grad()
+    def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
+
+        # Ensure device and dtype match the runtime
+        latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.autocast_dtype if self.device == "cuda" else latent_tensor.dtype)
+
+        # Build the timestep vector (one value per item in batch B)
+        num_items_in_batch = latent_tensor_gpu.shape[0]
+        timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device, dtype=latent_tensor_gpu.dtype)
+
         ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
         with ctx:
-
-
-
-
-
-
-
+            pixels = vae_decode(
+                latent_tensor_gpu,
+                self.pipeline.vae if hasattr(self.pipeline, "vae") else self.pipeline,  # compat: pipeline object or bare VAE
+                is_video=True,
+                timestep=timestep_tensor,
+                vae_per_channel_normalize=True,
+            )
+
         # Normalize to [0, 1] if the output comes in [-1, 1]
-        if
-
+        if pixels.min() < 0:
+            pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0
         else:
-
-
-
+            pixels = pixels.clamp(0, 1)
+        return pixels
+
 
 # Simple-to-use global singleton
 vae_manager_singleton = _SimpleVAEManager()
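For reference, a minimal usage sketch of the new decode path. It is hypothetical: how `pipeline` is loaded, the exact `attach_pipeline` signature, and the latent shape below are assumptions for illustration, not part of this commit.

    import torch
    from managers.vae_manager import vae_manager_singleton

    # `pipeline` stands in for an already-loaded video pipeline exposing a `.vae`
    # attribute (a bare VAE module would also be accepted by the compat branch above).
    vae_manager_singleton.attach_pipeline(pipeline)

    # Illustrative latent block with shape (B, C, T, H', W'); real sizes depend on the VAE.
    latents = torch.randn(1, 128, 8, 16, 16)

    pixels = vae_manager_singleton.decode(latents, decode_timestep=0.05)
    print(pixels.shape, float(pixels.min()), float(pixels.max()))  # values clamped to [0, 1]

Note that the new decode drops the old `self.pipeline is None` guard, so the pipeline must be attached before the first call, and the single decode_timestep value is expanded into one entry per batch item before being handed to vae_decode.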