alex committed on
Commit e37991a · 1 Parent(s): 2c8ec61

progress bar fixed

Files changed (3)
  1. app.py +15 -7
  2. humo/generate.py +9 -3
  3. humo/generate_1_7B.py +326 -46
app.py CHANGED
@@ -5,7 +5,7 @@ import os
import subprocess
import uuid
import shutil
-
+ from tqdm import tqdm


from huggingface_hub import snapshot_download, list_repo_files, hf_hub_download
@@ -93,7 +93,6 @@ config = load_config(
)
runner = create_object(config)

-
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{os.getcwd()}/torchinductor_space") # or another writable path

def restore_inductor_cache_from_hub(repo_id: str, filename: str = "torch_compile_cache.zip",
@@ -110,7 +109,7 @@ def restore_inductor_cache_from_hub(repo_id: str, filename: str = "torch_compile
# restore_inductor_cache_from_hub("alexnasa/humo-compiled")


- def get_duration(prompt_text, steps, image_file, audio_file_path, max_duration, session_id):
+ def get_duration(prompt_text, steps, image_file, audio_file_path, max_duration, session_id, progress):

    return calculate_required_time(steps, max_duration)

@@ -124,6 +123,15 @@ def calculate_required_time(steps, max_duration):
        70: 13,
        95: 21,
    }
+
+     # Humo 1.7
+     # max_duration_duration_mapping = {
+     #     20: 2,
+     #     45: 2,
+     #     70: 5,
+     #     95: 6,
+     # }
+
    each_step_s = max_duration_duration_mapping[max_duration]
    duration_s = (each_step_s * steps) + warmup_s

@@ -143,7 +151,7 @@ def update_required_time(steps, max_duration):
    return get_required_time_string(steps, max_duration)


- def generate_scene(prompt_text, steps, image_paths, audio_file_path, max_duration = 3, session_id = None):
+ def generate_scene(prompt_text, steps, image_paths, audio_file_path, max_duration = 3, session_id = None, progress=gr.Progress(),):

    prompt_text_check = (prompt_text or "").strip()
    if not prompt_text_check:
@@ -152,7 +160,7 @@ def generate_scene(prompt_text, steps, image_paths, audio_file_path, max_duratio
    if not audio_file_path and not image_paths:
        raise gr.Error("Please provide a reference image or a lipsync audio.")

-     return run_pipeline(prompt_text, steps, image_paths, audio_file_path, max_duration, session_id)
+     return run_pipeline(prompt_text, steps, image_paths, audio_file_path, max_duration, session_id, progress)

def upload_inductor_cache_to_hub(
    repo_id: str,
@@ -206,7 +214,7 @@ def upload_inductor_cache_to_hub(


@spaces.GPU(duration=get_duration)
- def run_pipeline(prompt_text, steps, image_paths, audio_file_path, max_duration = 3, session_id = None):
+ def run_pipeline(prompt_text, steps, image_paths, audio_file_path, max_duration = 3, session_id = None, progress=gr.Progress(),):

    if session_id is None:
        session_id = uuid.uuid4().hex
@@ -267,7 +275,6 @@ def run_pipeline(prompt_text, steps, image_paths, audio_file_path, max_duration
    width, height = 832, 480


-     # Run inference
    runner.inference_loop(
        prompt_text,
        img_paths,
@@ -280,6 +287,7 @@ def run_pipeline(prompt_text, steps, image_paths, audio_file_path, max_duration
        steps,
        frames = int(max_duration),
        tea_cache_l1_thresh = 0.0,
+         progress_bar_cmd=progress
    )

    # Return resulting video path
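
Note: the app.py change wires a Gradio progress tracker through the call chain. `generate_scene` and `run_pipeline` now take `progress=gr.Progress()`, and the tracker is handed to the runner as `progress_bar_cmd`. A minimal standalone sketch of the mechanism (a toy example, not this Space's code; the function and component names are illustrative):

```python
# Minimal sketch: Gradio injects a live tracker for a `progress=gr.Progress()`
# default argument, and progress.tqdm() mirrors tqdm while updating the UI.
import time
import gradio as gr

def slow_task(steps, progress=gr.Progress()):
    for _ in progress.tqdm(range(int(steps)), desc="denoising"):
        time.sleep(0.1)  # stand-in for one diffusion step
    return f"ran {int(steps)} steps"

demo = gr.Interface(fn=slow_task, inputs=gr.Number(value=10), outputs="text")

if __name__ == "__main__":
    demo.queue().launch()
```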
humo/generate.py CHANGED
@@ -680,6 +680,7 @@ class Generator():
                  n_prompt="",
                  seed=-1,
                  tea_cache_l1_thresh = 0.0,
+                   progress_bar_cmd = None,
                  device = get_device(),
                  ):

@@ -796,8 +797,11 @@ class Generator():
        arg_tia = {'seq_len': seq_len, 'audio': audio_emb, 'y': y_c, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}

        torch.cuda.empty_cache()
+
+         total_steps = len(timesteps)
+
        # self.dit.to(device=get_device())
-         for _, t in enumerate(tqdm(timesteps)):
+         for i, t in progress_bar_cmd.tqdm(enumerate(timesteps), desc=f"/{total_steps} Steps"):
            timestep = [t]
            timestep = torch.stack(timestep)

@@ -823,6 +827,7 @@ class Generator():
            del timestep
            torch.cuda.empty_cache()

+
        x0 = latents
        x0 = [x0_[:,:-latents_ref[0].shape[1]] for x0_ in x0]

@@ -848,7 +853,7 @@ class Generator():
        return videos[0] # if get_local_rank() == 0 else None


-     def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, inference_mode = "TIA", width = 832, height = 480, steps=50, frames = 97, tea_cache_l1_thresh = 0.0, seed = 0):
+     def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, inference_mode = "TIA", width = 832, height = 480, steps=50, frames = 97, tea_cache_l1_thresh = 0.0, progress_bar_cmd = None, seed = 0):

        video = self.inference(
            prompt,
@@ -861,7 +866,8 @@ class Generator():
            sampling_steps=steps,
            inference_mode = inference_mode,
            tea_cache_l1_thresh = tea_cache_l1_thresh,
-             seed=seed
+             seed=seed,
+             progress_bar_cmd = progress_bar_cmd
        )

        torch.cuda.empty_cache()
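
Note: the rewritten sampling loop calls `progress_bar_cmd.tqdm(...)` directly, so the `None` default added to `inference()` only works when a tracker is actually passed in (as app.py now does); calling it without one would raise `AttributeError`. A possible None-safe guard, sketched here as an illustration rather than as part of this commit:

```python
# Hedged sketch: fall back to plain tqdm when no Gradio tracker is supplied.
# `progress_bar_cmd` mirrors the parameter introduced in this commit.
from tqdm import tqdm

def iter_steps(timesteps, progress_bar_cmd=None, desc="Steps"):
    if progress_bar_cmd is not None:  # e.g. a gr.Progress() tracker
        return progress_bar_cmd.tqdm(enumerate(timesteps), desc=desc)
    return enumerate(tqdm(timesteps, desc=desc))  # CLI fallback
```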
humo/generate_1_7B.py CHANGED
@@ -18,6 +18,7 @@ import gc
import random
import sys
import mediapy
+ import numpy as np
import torch
import torch.distributed as dist
from omegaconf import DictConfig, ListConfig, OmegaConf
@@ -59,7 +60,15 @@ import torch.cuda.amp as amp
from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from humo.utils.audio_processor_whisper import AudioProcessor
from humo.utils.wav2vec import linear_interpolation_fps
+ from torchao.quantization import quantize_

+ import torch._dynamo as dynamo
+ dynamo.config.capture_scalar_outputs = True
+ torch.set_float32_matmul_precision("high")
+
+ import torch
+ import torch.nn as nn
+ import transformer_engine.pytorch as te

image_transform = Compose([
    ToTensor(),
@@ -96,14 +105,130 @@ def clever_format(nums, format="%.2f"):
    return clever_nums


+
+ # --- put near your imports ---
+ import torch
+ import torch.nn as nn
+ import contextlib
+ import transformer_engine.pytorch as te
+
+ # FP8 autocast compatibility for different TE versions
+ try:
+     # Preferred modern API
+     from transformer_engine.pytorch import fp8_autocast
+     try:
+         # Newer TE: use recipe-based API
+         from transformer_engine.common.recipe import DelayedScaling, Format
+         def make_fp8_ctx(enabled: bool = True):
+             if not enabled:
+                 return contextlib.nullcontext()
+             fp8_recipe = DelayedScaling(fp8_format=Format.E4M3)  # E4M3 format
+             return fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
+     except Exception:
+         # Very old variant that might still accept fp8_format directly
+         def make_fp8_ctx(enabled: bool = True):
+             # If TE doesn't have FP8Format, just no-op
+             if not hasattr(te, "FP8Format"):
+                 return contextlib.nullcontext()
+             return te.fp8_autocast(enabled=enabled, fp8_format=te.FP8Format.E4M3)
+ except Exception:
+     # TE not present or totally incompatible — no-op
+     def make_fp8_ctx(enabled: bool = True):
+         return contextlib.nullcontext()
+
+
+ # TE sometimes exposes Linear at different paths; this normalizes it.
+ try:
+     TELinear = te.Linear
+ except AttributeError:  # very old layouts
+     from transformer_engine.pytorch.modules.linear import Linear as TELinear  # type: ignore
+
+ # --- near imports ---
+ import torch
+ import torch.nn as nn
+ import transformer_engine.pytorch as te
+
+ try:
+     TELinear = te.Linear
+ except AttributeError:
+     from transformer_engine.pytorch.modules.linear import Linear as TELinear  # type: ignore
+
+ import torch
+ import torch.nn as nn
+ import transformer_engine.pytorch as te
+
+ try:
+     TELinear = te.Linear
+ except AttributeError:
+     from transformer_engine.pytorch.modules.linear import Linear as TELinear  # type: ignore
+
+ def _default_te_allow(fullname: str, lin: nn.Linear) -> bool:
+     """
+     Allow TE only where it's shape-safe & beneficial.
+     Skip small/special layers (time/timestep/pos embeds, heads).
+     Enforce multiples of 16 for in/out features (FP8 kernel friendly).
+     Also skip very small projections likely to see M=1.
+     """
+     blocked_keywords = (
+         "time_embedding", "timestep", "time_embed",
+         "time_projection", "pos_embedding", "pos_embed",
+         "to_logits", "logits", "final_proj", "proj_out", "output_projection",
+     )
+     if any(k in fullname for k in blocked_keywords):
+         return False
+
+     # TE FP8 kernels like K, N divisible by 16
+     if lin.in_features % 16 != 0 or lin.out_features % 16 != 0:
+         return False
+
+     # Heuristic: avoid tiny layers; keeps attention/MLP, skips small MLPs
+     if lin.in_features < 512 or lin.out_features < 512:
+         return False
+
+     # Whitelist: only convert inside transformer blocks if you know their prefix
+     # This further reduces risk of catching special heads elsewhere.
+     allowed_context = ("blocks", "layers", "transformer", "attn", "mlp", "ffn")
+     if not any(tok in fullname for tok in allowed_context):
+         return False
+
+     return True
+
+ @torch.no_grad()
+ def convert_linears_to_te_fp8(module: nn.Module, allow_pred=_default_te_allow, _prefix=""):
+     for name, child in list(module.named_children()):
+         full = f"{_prefix}.{name}" if _prefix else name
+         convert_linears_to_te_fp8(child, allow_pred, full)
+
+         if isinstance(child, nn.Linear):
+             if allow_pred is not None and not allow_pred(full, child):
+                 continue
+
+             te_lin = TELinear(
+                 in_features=child.in_features,
+                 out_features=child.out_features,
+                 bias=(child.bias is not None),
+                 params_dtype=torch.bfloat16,
+             ).to(child.weight.device)
+
+             te_lin.weight.copy_(child.weight.to(te_lin.weight.dtype))
+             if child.bias is not None:
+                 te_lin.bias.copy_(child.bias.to(te_lin.bias.dtype))
+
+             setattr(module, name, te_lin)
+     return module
+
class Generator():
    def __init__(self, config: DictConfig):
        self.config = config.copy()
        OmegaConf.set_readonly(self.config, True)
        self.logger = get_logger(self.__class__.__name__)
-         self.configure_models()

        # init_torch(cudnn_benchmark=False)
+         self.configure_models()
+
+     def entrypoint(self):
+
+         self.inference_loop()

    def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config):
        device_mesh = None
@@ -115,43 +240,63 @@ class Generator():
            device_mesh = init_device_mesh("cuda", tuple(device_mesh_config))
        return device_mesh, fsdp_strategy

+
    def configure_models(self):
-         self.configure_dit_model(device="cpu")
-         self.configure_vae_model()
+         self.configure_dit_model(device="cuda")
+
+         self.dit.eval().to("cuda")
+         convert_linears_to_te_fp8(self.dit)
+
+         self.dit = torch.compile(self.dit, )
+
+
+         self.configure_vae_model(device="cuda")
        if self.config.generation.get('extract_audio_feat', False):
            self.configure_wav2vec(device="cpu")
-         self.configure_text_model(device="cpu")
+         self.configure_text_model(device="cuda")
+
+         # # Initialize fsdp.
+         # self.configure_dit_fsdp_model()
+         # self.configure_text_fsdp_model()
+
+         # quantize_(self.text_encoder, Int8WeightOnlyConfig())
+         # quantize_(self.dit, Float8DynamicActivationFloat8WeightConfig())

-         # Initialize fsdp.
-         self.configure_dit_fsdp_model()
-         self.configure_text_fsdp_model()

    def configure_dit_model(self, device=get_device()):

        init_unified_parallel(self.config.dit.sp_size)
        self.sp_size = get_unified_parallel_world_size()
-
-         # Create dit model.
+
+         # Create DiT model on meta, then mark dtype as bfloat16 (no real allocation yet).
        init_device = "meta"
        with torch.device(init_device):
            self.dit = create_object(self.config.dit.model)
+         self.dit = self.dit.to(dtype=torch.bfloat16)  # or: self.dit.bfloat16()
        self.logger.info(f"Load DiT model on {init_device}.")
        self.dit.eval().requires_grad_(False)

        # Load dit checkpoint.
        path = self.config.dit.checkpoint_dir
+
+         def _cast_state_dict_to_bf16(state):
+             for k, v in state.items():
+                 if isinstance(v, torch.Tensor) and v.is_floating_point():
+                     state[k] = v.to(dtype=torch.bfloat16, copy=False)
+             return state
+
        if path.endswith(".pth"):
-             state = torch.load(path, map_location=device, mmap=True)
+             # Load to CPU first; we’ll move the model later.
+             state = torch.load(path, map_location="cpu", mmap=True)
+             state = _cast_state_dict_to_bf16(state)
            missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True)
            self.logger.info(
-                 f"dit loaded from {path}. "
-                 f"Missing keys: {len(missing_keys)}, "
-                 f"Unexpected keys: {len(unexpected_keys)}"
+                 f"dit loaded from {path}. Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}"
            )
        else:
            from safetensors.torch import load_file
            import json
-             def load_custom_sharded_weights(model_dir, base_name, device=device):
+             def load_custom_sharded_weights(model_dir, base_name):
                index_path = f"{model_dir}/{base_name}.safetensors.index.json"
                with open(index_path, "r") as f:
                    index = json.load(f)
@@ -160,23 +305,28 @@ class Generator():
                state_dict = {}
                for shard_file in shard_files:
                    shard_path = f"{model_dir}/{shard_file}"
-                     shard_state = load_file(shard_path)
-                     shard_state = {k: v.to(device) for k, v in shard_state.items()}
+                     # Load on CPU, then cast to bf16; we’ll move the whole module later.
+                     shard_state = load_file(shard_path, device="cpu")
+                     shard_state = {k: (v.to(dtype=torch.bfloat16, copy=False) if v.is_floating_point() else v)
+                                    for k, v in shard_state.items()}
                    state_dict.update(shard_state)
                return state_dict
-             state = load_custom_sharded_weights(path, 'humo', device)
+
+             state = load_custom_sharded_weights(path, 'humo')
            self.dit.load_state_dict(state, strict=False, assign=True)
-
+
        self.dit = meta_non_persistent_buffer_init_fn(self.dit)
-         if device in [get_device(), "cuda"]:
-             self.dit.to(get_device())
+
+         target_device = get_device() if device in [get_device(), "cuda"] else device
+         self.dit.to(target_device)  # dtype already bf16

        # Print model size.
        params = sum(p.numel() for p in self.dit.parameters())
        self.logger.info(
            f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}"
        )
-
+
+
    def configure_vae_model(self, device=get_device()):
        self.vae_stride = self.config.vae.vae_stride
        self.vae = WanVAE(
@@ -216,15 +366,93 @@ class Generator():


    def configure_dit_fsdp_model(self):
-         self.dit.to(get_device())
-
-         return
+         from humo.models.wan_modules.model_humo import WanAttentionBlock
+
+         dit_blocks = (WanAttentionBlock,)
+
+         # Init model_shard_cpu_group for saving checkpoint with sharded state_dict.
+         init_model_shard_cpu_group(
+             self.config.dit.fsdp.sharding_strategy,
+             self.config.dit.fsdp.get("device_mesh", None),
+         )
+
+         # Assert that dit has wrappable blocks.
+         assert any(isinstance(m, dit_blocks) for m in self.dit.modules())
+
+         # Define wrap policy on all dit blocks.
+         def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
+             return recurse or isinstance(module, dit_blocks)
+
+         # Configure FSDP settings.
+         device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
+             self.config.dit.fsdp.sharding_strategy,
+             self.config.dit.fsdp.get("device_mesh", None),
+         )
+         settings = dict(
+             auto_wrap_policy=custom_auto_wrap_policy,
+             sharding_strategy=fsdp_strategy,
+             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+             device_id=get_local_rank(),
+             use_orig_params=False,
+             sync_module_states=True,
+             forward_prefetch=True,
+             limit_all_gathers=False,  # False for ZERO2.
+             mixed_precision=MixedPrecision(
+                 param_dtype=torch.bfloat16,
+                 reduce_dtype=torch.float32,
+                 buffer_dtype=torch.float32,
+             ),
+             device_mesh=device_mesh,
+             param_init_fn=meta_param_init_fn,
+         )
+
+         # Apply FSDP.
+         self.dit = FullyShardedDataParallel(self.dit, **settings)
+         # self.dit.to(get_device())


    def configure_text_fsdp_model(self):
-         self.text_encoder.to(get_device())
-
-         return
+         # If FSDP is not enabled, put text_encoder to GPU and return.
+         if not self.config.text.fsdp.enabled:
+             self.text_encoder.to(get_device())
+             return
+
+         # from transformers.models.t5.modeling_t5 import T5Block
+         from humo.models.wan_modules.t5 import T5SelfAttention
+
+         text_blocks = (torch.nn.Embedding, T5SelfAttention)
+         # text_blocks_names = ("QWenBlock", "QWenModel")  # QWen cannot be imported. Use str.
+
+         def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
+             return (
+                 recurse
+                 or isinstance(module, text_blocks)
+             )
+
+         # Apply FSDP.
+         text_encoder_dtype = getattr(torch, self.config.text.dtype)
+         device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
+             self.config.text.fsdp.sharding_strategy,
+             self.config.text.fsdp.get("device_mesh", None),
+         )
+         self.text_encoder = FullyShardedDataParallel(
+             module=self.text_encoder,
+             auto_wrap_policy=custom_auto_wrap_policy,
+             sharding_strategy=fsdp_strategy,
+             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+             device_id=get_local_rank(),
+             use_orig_params=False,
+             sync_module_states=False,
+             forward_prefetch=True,
+             limit_all_gathers=True,
+             mixed_precision=MixedPrecision(
+                 param_dtype=text_encoder_dtype,
+                 reduce_dtype=text_encoder_dtype,
+                 buffer_dtype=text_encoder_dtype,
+             ),
+             device_mesh=device_mesh,
+         )
+         self.text_encoder.to(get_device()).requires_grad_(False)


    def load_image_latent_ref_id(self, path: str, size, device):
@@ -390,7 +618,6 @@ class Generator():
            neg

        return noise_pred
-

    @torch.no_grad()
    def inference(self,
@@ -401,20 +628,22 @@ class Generator():
                  frame_num=81,
                  shift=5.0,
                  sample_solver='unipc',
+                   inference_mode='TIA',
                  sampling_steps=50,
                  n_prompt="",
                  seed=-1,
-                   offload_model=True,
+                   tea_cache_l1_thresh = 0.0,
                  device = get_device(),
                  ):

-         self.vae.model.to(device=device)
+         # self.vae.model.to(device=device)
        if img_path is not None:
            latents_ref = self.load_image_latent_ref_id(img_path, size, device)
        else:
            latents_ref = [torch.zeros(16, 1, size[1]//8, size[0]//8).to(device)]

-         self.vae.model.to(device="cpu")
+         # self.vae.model.to(device="cpu")
+
        latents_ref_neg = [torch.zeros_like(latent_ref) for latent_ref in latents_ref]

        # audio
@@ -456,10 +685,10 @@ class Generator():
        seed_g = torch.Generator(device=device)
        seed_g.manual_seed(seed)

-         self.text_encoder.model.to(device)
+         # self.text_encoder.model.to(device)
        context = self.text_encoder([input_prompt], device)
        context_null = self.text_encoder([n_prompt], device)
-         self.text_encoder.model.cpu()
+         # self.text_encoder.model.cpu()

        noise = [
            torch.randn(
@@ -477,10 +706,9 @@ class Generator():
                yield

        no_sync = getattr(self.dit, 'no_sync', noop_no_sync)
-         # step_change = self.config.generation.step_change # 980

        # evaluation mode
-         with amp.autocast(dtype=torch.bfloat16), torch.no_grad(), no_sync():
+         with make_fp8_ctx(True), torch.autocast('cuda', dtype=torch.bfloat16), torch.no_grad(), no_sync():

            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
@@ -500,7 +728,7 @@ class Generator():
            arg_null = {'context': context_null, 'seq_len': seq_len, 'audio': audio_emb_neg}

            torch.cuda.empty_cache()
-             self.dit.to(device=get_device())
+
            for _, t in enumerate(tqdm(timesteps)):
                timestep = [t]
                timestep = torch.stack(timestep)
@@ -527,12 +755,13 @@ class Generator():
            x0 = [x0_[:,:-latents_ref[0].shape[1]] for x0_ in x0]

            # if offload_model:
-             self.dit.cpu()
+             # self.dit.cpu()
+
            torch.cuda.empty_cache()
            # if get_local_rank() == 0:
-             self.vae.model.to(device=device)
+             # self.vae.model.to(device=device)
            videos = self.vae.decode(x0)
-             self.vae.model.to(device="cpu")
+             # self.vae.model.to(device="cpu")

            del noise, latents, noise_pred
            del audio_emb, audio_emb_neg, latents_ref, latents_ref_neg, context, context_null
@@ -547,8 +776,7 @@ class Generator():
        return videos[0] # if get_local_rank() == 0 else None


-     def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, width = 832, height = 480, steps=50, frames = 97, seed = 0):
-         print(f'ref_img_path:{ref_img_path}')
+     def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, inference_mode = "TIA", width = 832, height = 480, steps=50, frames = 97, tea_cache_l1_thresh = 0.0, seed = 0):

        video = self.inference(
            prompt,
@@ -559,14 +787,14 @@ class Generator():
            shift=self.config.diffusion.timesteps.sampling.shift,
            sample_solver='unipc',
            sampling_steps=steps,
-             seed=seed,
-             offload_model=False,
+             inference_mode = inference_mode,
+             tea_cache_l1_thresh = tea_cache_l1_thresh,
+             seed=seed
        )

        torch.cuda.empty_cache()
        gc.collect()

-
        # Save samples.
        if get_sequence_parallel_rank() == 0:
            pathname = self.save_sample(
@@ -580,7 +808,6 @@ class Generator():
        del video, prompt
        torch.cuda.empty_cache()
        gc.collect()
-


    def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str):
@@ -619,4 +846,57 @@ class Generator():
            raise NotImplementedError
        assert isinstance(pos_prompts, ListConfig)

-         return pos_prompts
+         return pos_prompts
+
+ class TeaCache:
+     def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
+         self.num_inference_steps = num_inference_steps
+         self.step = 0
+         self.accumulated_rel_l1_distance = 0
+         self.previous_modulated_input = None
+         self.rel_l1_thresh = rel_l1_thresh
+         self.previous_residual = None
+         self.previous_hidden_states = None
+
+         self.coefficients_dict = {
+             "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
+             "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
+             "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
+             "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
+         }
+         if model_id not in self.coefficients_dict:
+             supported_model_ids = ", ".join([i for i in self.coefficients_dict])
+             raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
+         self.coefficients = self.coefficients_dict[model_id]
+
+     def check(self, dit, x, t_mod):
+         modulated_inp = t_mod.clone()
+         if self.step == 0 or self.step == self.num_inference_steps - 1:
+             should_calc = True
+             self.accumulated_rel_l1_distance = 0
+         else:
+             coefficients = self.coefficients
+             rescale_func = np.poly1d(coefficients)
+             self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+             if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                 should_calc = False
+             else:
+                 should_calc = True
+                 self.accumulated_rel_l1_distance = 0
+         self.previous_modulated_input = modulated_inp
+         self.step += 1
+         if self.step == self.num_inference_steps:
+             self.step = 0
+         if should_calc:
+             self.previous_hidden_states = x.clone()
+         return not should_calc
+
+     def store(self, hidden_states):
+         if self.previous_hidden_states is None:
+             return
+         self.previous_residual = hidden_states - self.previous_hidden_states
+         self.previous_hidden_states = None
+
+     def update(self, hidden_states):
+         hidden_states = hidden_states + self.previous_residual
+         return hidden_states
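
Note: for reference, a hedged usage sketch of the helpers added in this file. `convert_linears_to_te_fp8` swaps eligible `nn.Linear` layers for Transformer Engine linears, and `make_fp8_ctx` wraps the forward pass in FP8 autocast. It assumes `transformer_engine` is installed and an FP8-capable GPU is available; the toy module and layer names are illustrative only:

```python
# Toy module: "attn" is large and inside an allowed context, so it gets
# converted; "proj_out" is blocked by _default_te_allow and stays nn.Linear.
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.attn = nn.Linear(dim, dim)
        self.proj_out = nn.Linear(dim, 16)

    def forward(self, x):
        return self.proj_out(self.attn(x))

model = ToyBlock().to("cuda", dtype=torch.bfloat16).eval()
convert_linears_to_te_fp8(model)  # swaps only the eligible Linear

x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
with make_fp8_ctx(True), torch.no_grad():
    y = model(x)  # eligible matmuls run through TE FP8 kernels
print(y.shape)
```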