EuuIia commited on
Commit
441491f
verified
1 Parent(s): 63cc2df

Update managers/vae_manager.py

Browse files
Files changed (1) hide show
  1. managers/vae_manager.py +23 -22
managers/vae_manager.py CHANGED
@@ -22,34 +22,35 @@ class _SimpleVAEManager:
22
  if autocast_dtype is not None:
23
  self.autocast_dtype = autocast_dtype
24
 
25
- @torch.no_grad()
26
- def decode(self, latents_5d: torch.Tensor) -> torch.Tensor:
27
- """
28
- Decodifica todo o bloco 5D de uma vez, replicando o fluxo simples do deformes4D.
29
- Retorna tensor de pixels 5D em [0,1] com shape (B,C,T,H',W').
30
- """
31
- if self.pipeline is None:
32
- raise RuntimeError("VAE Manager sem pipeline. Chame attach_pipeline primeiro.")
33
 
34
- # Garante device correto
35
- latents_5d = latents_5d.to(self.device, non_blocking=True)
36
 
 
 
 
 
 
 
 
 
 
 
37
  ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
38
  with ctx:
39
- if hasattr(self.pipeline, "decode_latents"):
40
- pixels_5d = self.pipeline.decode_latents(latents_5d)
41
- elif hasattr(self.pipeline, "vae") and hasattr(self.pipeline.vae, "decode"):
42
- pixels_5d = self.pipeline.vae.decode(latents_5d)
43
- else:
44
- raise RuntimeError("Pipeline n茫o exp玫e decode_latents nem vae.decode.")
45
-
 
46
  # Normaliza para [0,1] se vier em [-1,1]
47
- if pixels_5d.min() < 0:
48
- pixels_5d = (pixels_5d.clamp(-1, 1) + 1.0) / 2.0
49
  else:
50
- pixels_5d = pixels_5d.clamp(0, 1)
51
-
52
- return pixels_5d
53
 
54
  # Singleton global de uso simples
55
  vae_manager_singleton = _SimpleVAEManager()
 
22
  if autocast_dtype is not None:
23
  self.autocast_dtype = autocast_dtype
24
 
 
 
 
 
 
 
 
 
25
 
 
 
26
 
27
+ @torch.no_grad()
28
+ def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
29
+
30
+ # Garante device e dtype conforme runtime
31
+ latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.autocast_dtype if self.device == "cuda" else latent_tensor.dtype)
32
+
33
+ # Constr贸i o vetor de timesteps (um por item no batch B)
34
+ num_items_in_batch = latent_tensor_gpu.shape[0]
35
+ timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device, dtype=latent_tensor_gpu.dtype)
36
+
37
  ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
38
  with ctx:
39
+ pixels = vae_decode(
40
+ latent_tensor_gpu,
41
+ self.pipeline.vae if hasattr(self.pipeline, "vae") else self.pipeline, # compat
42
+ is_video=True,
43
+ timestep=timestep_tensor,
44
+ vae_per_channel_normalize=True,
45
+ )
46
+
47
  # Normaliza para [0,1] se vier em [-1,1]
48
+ if pixels.min() < 0:
49
+ pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0
50
  else:
51
+ pixels = pixels.clamp(0, 1)
52
+ return pixels
53
+
54
 
55
  # Singleton global de uso simples
56
  vae_manager_singleton = _SimpleVAEManager()