harry900000 committed
Commit 09a9db5 · 1 Parent(s): 74308ee

move the inference function to `helper.py`

Files changed (2):
  1. app.py +9 -323
  2. helper.py +318 -4
app.py CHANGED
@@ -1,24 +1,6 @@
-import datetime
 import os
-import sys
-import tempfile
-import time
-import zipfile
-from typing import List, Tuple
-
-import gradio as gr
-import spaces
-
-from gpu_info import stop_watcher, watch_gpu_memory
-
-PWD = os.path.dirname(__file__)
-CHECKPOINTS_PATH = "/data/checkpoints"
-LOG_DIR = os.path.join(PWD, "logs")
-os.makedirs(LOG_DIR, exist_ok=True)
 
 try:
-    import os
-
     from huggingface_hub import login
 
     # Try to login with token from environment variable
@@ -31,12 +13,8 @@ try:
 except Exception as e:
     print(f"Authentication failed: {e}")
 
-# download checkpoints
-from download_checkpoints import main as download_checkpoints
-
-os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
-download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
 
+import sys
 
 from test_environment import main as check_environment
 from test_environment import setup_environment
@@ -46,314 +24,21 @@ setup_environment()
 # setup env
 os.environ["CUDA_HOME"] = "/usr/local/cuda"
 os.environ["LD_LIBRARY_PATH"] = "$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
-os.environ["PATH"] = "$CUDA_HOME/bin:/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:$PATH"
+os.environ["PATH"] = "$CUDA_HOME/bin:/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:$PATH"
 
 if not check_environment():
     sys.exit(1)
 
-os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Workaround to suppress MP warning
-
-import copy
-import json
-import random
-from io import BytesIO
-
-import torch
-
-from cosmos_transfer1.checkpoints import (
-    BASE_7B_CHECKPOINT_AV_SAMPLE_PATH,
-    BASE_7B_CHECKPOINT_PATH,
-    EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH,
-)
-from cosmos_transfer1.diffusion.inference.inference_utils import (
-    validate_controlnet_specs,
-)
-from cosmos_transfer1.diffusion.inference.preprocessors import Preprocessors
-from cosmos_transfer1.diffusion.inference.world_generation_pipeline import (
-    DiffusionControl2WorldGenerationPipeline,
-    DistilledControl2WorldGenerationPipeline,
-)
-from cosmos_transfer1.utils import log, misc
-from cosmos_transfer1.utils.io import read_prompts_from_file, save_video
-from helper import parse_arguments
-
-torch.enable_grad(False)
-torch.serialization.add_safe_globals([BytesIO])
-
-
-def inference(cfg, control_inputs, chunking) -> Tuple[List[str], List[str]]:
-    video_paths = []
-    prompt_paths = []
-
-    control_inputs = validate_controlnet_specs(cfg, control_inputs)
-    misc.set_random_seed(cfg.seed)
-
-    device_rank = 0
-    process_group = None
-    if cfg.num_gpus > 1:
-        from megatron.core import parallel_state
-
-        from cosmos_transfer1.utils import distributed
-
-        distributed.init()
-        parallel_state.initialize_model_parallel(context_parallel_size=cfg.num_gpus)
-        process_group = parallel_state.get_context_parallel_group()
-
-        device_rank = distributed.get_rank(process_group)
-
-    preprocessors = Preprocessors()
-
-    if cfg.use_distilled:
-        assert not cfg.is_av_sample
-        checkpoint = EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH
-        pipeline = DistilledControl2WorldGenerationPipeline(
-            checkpoint_dir=cfg.checkpoint_dir,
-            checkpoint_name=checkpoint,
-            offload_network=cfg.offload_diffusion_transformer,
-            offload_text_encoder_model=cfg.offload_text_encoder_model,
-            offload_guardrail_models=cfg.offload_guardrail_models,
-            guidance=cfg.guidance,
-            num_steps=cfg.num_steps,
-            fps=cfg.fps,
-            seed=cfg.seed,
-            num_input_frames=cfg.num_input_frames,
-            control_inputs=control_inputs,
-            sigma_max=cfg.sigma_max,
-            blur_strength=cfg.blur_strength,
-            canny_threshold=cfg.canny_threshold,
-            upsample_prompt=cfg.upsample_prompt,
-            offload_prompt_upsampler=cfg.offload_prompt_upsampler,
-            process_group=process_group,
-        )
-    else:
-        checkpoint = BASE_7B_CHECKPOINT_AV_SAMPLE_PATH if cfg.is_av_sample else BASE_7B_CHECKPOINT_PATH
-
-        # Initialize transfer generation model pipeline
-        pipeline = DiffusionControl2WorldGenerationPipeline(
-            checkpoint_dir=cfg.checkpoint_dir,
-            checkpoint_name=checkpoint,
-            offload_network=cfg.offload_diffusion_transformer,
-            offload_text_encoder_model=cfg.offload_text_encoder_model,
-            offload_guardrail_models=cfg.offload_guardrail_models,
-            guidance=cfg.guidance,
-            num_steps=cfg.num_steps,
-            fps=cfg.fps,
-            seed=cfg.seed,
-            num_input_frames=cfg.num_input_frames,
-            control_inputs=control_inputs,
-            sigma_max=cfg.sigma_max,
-            blur_strength=cfg.blur_strength,
-            canny_threshold=cfg.canny_threshold,
-            upsample_prompt=cfg.upsample_prompt,
-            offload_prompt_upsampler=cfg.offload_prompt_upsampler,
-            process_group=process_group,
-            chunking=chunking,
-        )
-
-    if cfg.batch_input_path:
-        log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
-        prompts = read_prompts_from_file(cfg.batch_input_path)
-    else:
-        # Single prompt case
-        prompts = [{"prompt": cfg.prompt, "visual_input": cfg.input_video_path}]
-
-    batch_size = cfg.batch_size if hasattr(cfg, "batch_size") else 1
-    if any("upscale" in control_input for control_input in control_inputs) and batch_size > 1:
-        batch_size = 1
-        log.info("Setting batch_size=1 as upscale does not support batch generation")
-    os.makedirs(cfg.video_save_folder, exist_ok=True)
-    for batch_start in range(0, len(prompts), batch_size):
-        # Get current batch
-        batch_prompts = prompts[batch_start : batch_start + batch_size]
-        actual_batch_size = len(batch_prompts)
-        # Extract batch data
-        batch_prompt_texts = [p.get("prompt", None) for p in batch_prompts]
-        batch_video_paths = [p.get("visual_input", None) for p in batch_prompts]
-
-        batch_control_inputs = []
-        for i, input_dict in enumerate(batch_prompts):
-            current_prompt = input_dict.get("prompt", None)
-            current_video_path = input_dict.get("visual_input", None)
-
-            if cfg.batch_input_path:
-                video_save_subfolder = os.path.join(cfg.video_save_folder, f"video_{batch_start+i}")
-                os.makedirs(video_save_subfolder, exist_ok=True)
-            else:
-                video_save_subfolder = cfg.video_save_folder
-
-            current_control_inputs = copy.deepcopy(control_inputs)
-            if "control_overrides" in input_dict:
-                for hint_key, override in input_dict["control_overrides"].items():
-                    if hint_key in current_control_inputs:
-                        current_control_inputs[hint_key].update(override)
-                    else:
-                        log.warning(f"Ignoring unknown control key in override: {hint_key}")
-
-            # if control inputs are not provided, run respective preprocessor (for seg and depth)
-            log.info("running preprocessor")
-            preprocessors(
-                current_video_path,
-                current_prompt,
-                current_control_inputs,
-                video_save_subfolder,
-                cfg.regional_prompts if hasattr(cfg, "regional_prompts") else None,
-            )
-            batch_control_inputs.append(current_control_inputs)
-
-        regional_prompts = []
-        region_definitions = []
-        if hasattr(cfg, "regional_prompts") and cfg.regional_prompts:
-            log.info(f"regional_prompts: {cfg.regional_prompts}")
-            for regional_prompt in cfg.regional_prompts:
-                regional_prompts.append(regional_prompt["prompt"])
-                if "region_definitions_path" in regional_prompt:
-                    log.info(f"region_definitions_path: {regional_prompt['region_definitions_path']}")
-                    region_definition_path = regional_prompt["region_definitions_path"]
-                    if isinstance(region_definition_path, str) and region_definition_path.endswith(".json"):
-                        with open(region_definition_path, "r") as f:
-                            region_definitions_json = json.load(f)
-                        region_definitions.extend(region_definitions_json)
-                    else:
-                        region_definitions.append(region_definition_path)
-
-        if hasattr(pipeline, "regional_prompts"):
-            pipeline.regional_prompts = regional_prompts
-        if hasattr(pipeline, "region_definitions"):
-            pipeline.region_definitions = region_definitions
-
-        # Generate videos in batch
-        batch_outputs = pipeline.generate(
-            prompt=batch_prompt_texts,
-            video_path=batch_video_paths,
-            negative_prompt=cfg.negative_prompt,
-            control_inputs=batch_control_inputs,
-            save_folder=video_save_subfolder,
-            batch_size=actual_batch_size,
-        )
-        if batch_outputs is None:
-            log.critical("Guardrail blocked generation for entire batch.")
-            continue
-
-        videos, final_prompts = batch_outputs
-        for i, (video, prompt) in enumerate(zip(videos, final_prompts)):
-            if cfg.batch_input_path:
-                video_save_subfolder = os.path.join(cfg.video_save_folder, f"video_{batch_start+i}")
-                video_save_path = os.path.join(video_save_subfolder, "output.mp4")
-                prompt_save_path = os.path.join(video_save_subfolder, "prompt.txt")
-            else:
-                video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
-                prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")
-            # Save video and prompt
-            if device_rank == 0:
-                os.makedirs(os.path.dirname(video_save_path), exist_ok=True)
-                save_video(
-                    video=video,
-                    fps=cfg.fps,
-                    H=video.shape[1],
-                    W=video.shape[2],
-                    video_save_quality=5,
-                    video_save_path=video_save_path,
-                )
-                video_paths.append(video_save_path)
-
-                # Save prompt to text file alongside video
-                with open(prompt_save_path, "wb") as f:
-                    f.write(prompt.encode("utf-8"))
-
-                prompt_paths.append(prompt_save_path)
-
-                log.info(f"Saved video to {video_save_path}")
-                log.info(f"Saved prompt to {prompt_save_path}")
-
-    # clean up properly
-    if cfg.num_gpus > 1:
-        parallel_state.destroy_model_parallel()
-        import torch.distributed as dist
-
-        dist.destroy_process_group()
-
-    return video_paths, prompt_paths
-
-
-def create_zip_for_download(filename, files_to_zip):
-    temp_dir = tempfile.mkdtemp()
-    zip_path = os.path.join(temp_dir, f"{os.path.splitext(filename)[0]}.zip")
-
-    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
-        for file_path in files_to_zip:
-            arcname = os.path.basename(file_path)
-            zipf.write(file_path, arcname)
-
-    return zip_path
-
-
-@spaces.GPU()
-def generate_video(
-    rgb_video_path,
-    hdmap_video_input,
-    lidar_video_input,
-    prompt,
-    negative_prompt="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.",  # noqa: E501
-    seed=42,
-    randomize_seed=False,
-    chunking=None,
-    progress=gr.Progress(track_tqdm=True),
-):
-    _dt = datetime.datetime.now(tz=datetime.timezone(datetime.timedelta(hours=8))).strftime("%Y-%m-%d_%H.%M.%S")
-    logfile_path = os.path.join(LOG_DIR, f"{_dt}.log")
-    log_handler = log.init_dev_loguru_file(logfile_path)
-
-    if randomize_seed:
-        actual_seed = random.randint(0, 1000000)
-    else:
-        actual_seed = seed
-
-    log.info(f"actual_seed: {actual_seed}")
-
-    if rgb_video_path is None or not os.path.isfile(rgb_video_path):
-        log.warning(f"File `{rgb_video_path}` does not exist")
-        rgb_video_path = ""
-
-    # add timer to calculate the generation time
-    start_time = time.time()
-
-    # parse generation configs
-    args, control_inputs = parse_arguments(
-        controlnet_specs_in={
-            "hdmap": {"control_weight": 0.3, "input_control": hdmap_video_input},
-            "lidar": {"control_weight": 0.7, "input_control": lidar_video_input},
-        },
-        input_video_path=rgb_video_path,
-        checkpoint_dir=CHECKPOINTS_PATH,
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        sigma_max=80,
-        offload_text_encoder_model=True,
-        is_av_sample=True,
-        num_gpus=1,
-        seed=seed,
-    )
-
-    # watch gpu memory
-    watcher = watch_gpu_memory(10, lambda x: log.debug(f"GPU memory usage: {x} (MiB)"))
-
-    # start inference
-    if chunking <= 0:
-        chunking = None
-    videos, prompts = inference(args, control_inputs, chunking)
-
-    # print the generation time
-    end_time = time.time()
-    log.info(f"Time taken: {end_time - start_time} s")
-
-    # stop the watcher
-    stop_watcher()
-
-    video = videos[0]
-
-    log.logger.remove(log_handler)
-    return video, create_zip_for_download(filename=logfile_path, files_to_zip=[video, logfile_path]), actual_seed
+
+from download_checkpoints import main as download_checkpoints
+from helper import CHECKPOINTS_PATH, generate_video
+
+# download checkpoints
+os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
+download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
+
+
+import gradio as gr
 
 # Define the Gradio Blocks interface
 with gr.Blocks() as demo:
@@ -412,3 +97,4 @@ with gr.Blocks() as demo:
 
 if __name__ == "__main__":
     demo.launch()
+    # demo.launch(server_name="0.0.0.0")
helper.py CHANGED
@@ -1,10 +1,45 @@
 import argparse
+import copy
+import datetime
+import json
+import os
+import random
 import sys
-from typing import Any, Dict, Literal, Optional
+import tempfile
+import time
+import zipfile
+from io import BytesIO
+from typing import Any, Dict, List, Literal, Optional, Tuple
 
-sys.path.append("./cosmos-transfer1")
+import torch
 
-from cosmos_transfer1.diffusion.inference.inference_utils import valid_hint_keys
+from cosmos_transfer1.checkpoints import (
+    BASE_7B_CHECKPOINT_AV_SAMPLE_PATH,
+    BASE_7B_CHECKPOINT_PATH,
+    EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH,
+)
+from cosmos_transfer1.diffusion.inference.inference_utils import (
+    valid_hint_keys,
+    validate_controlnet_specs,
+)
+from cosmos_transfer1.diffusion.inference.preprocessors import Preprocessors
+from cosmos_transfer1.diffusion.inference.world_generation_pipeline import (
+    DiffusionControl2WorldGenerationPipeline,
+    DistilledControl2WorldGenerationPipeline,
+)
+from cosmos_transfer1.utils import log, misc
+from cosmos_transfer1.utils.io import read_prompts_from_file, save_video
+from gpu_info import stop_watcher, watch_gpu_memory
+
+PWD = os.path.dirname(__file__)
+CHECKPOINTS_PATH = "/data/checkpoints"
+LOG_DIR = os.path.join(PWD, "logs")
+os.makedirs(LOG_DIR, exist_ok=True)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Workaround to suppress MP warning
+
+torch.enable_grad(False)
+torch.serialization.add_safe_globals([BytesIO])
 
 
 def load_controlnet_specs(controlnet_specs_in: dict) -> Dict[str, Any]:
@@ -62,7 +97,8 @@ def parse_arguments(
     :param int num_input_frames: Number of conditional frames for long video generation
     :param float sigma_max: sigma_max for partial denoising
     :param str blur_strength: blur strength
-    :param str canny_threshold: blur strength of canny threshold applied to input. Lower means less blur or more detected edges, which means higher fidelity to input
+    :param str canny_threshold: blur strength of canny threshold applied to input. Lower means less blur or more detected edges,
+        which means higher fidelity to input
     :param bool is_av_sample: Whether the model is an driving post-training model
     :param str checkpoint_dir: Base directory containing model checkpoints
     :param str tokenizer_dir: Tokenizer weights directory relative to checkpoint_dir
@@ -121,3 +157,281 @@
         setattr(cmd_args, key, json_args[key])
 
     return cmd_args, control_inputs
+
+
+def inference(cfg, control_inputs, chunking) -> Tuple[List[str], List[str]]:
+    video_paths = []
+    prompt_paths = []
+
+    control_inputs = validate_controlnet_specs(cfg, control_inputs)
+    misc.set_random_seed(cfg.seed)
+
+    device_rank = 0
+    process_group = None
+    if cfg.num_gpus > 1:
+        from megatron.core import (
+            parallel_state,  # pyright: ignore[reportMissingImports]
+        )
+
+        from cosmos_transfer1.utils import distributed
+
+        distributed.init()
+        parallel_state.initialize_model_parallel(context_parallel_size=cfg.num_gpus)
+        process_group = parallel_state.get_context_parallel_group()
+
+        device_rank = distributed.get_rank(process_group)
+
+    preprocessors = Preprocessors()
+
+    if cfg.use_distilled:
+        assert not cfg.is_av_sample
+        checkpoint = EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH
+        pipeline = DistilledControl2WorldGenerationPipeline(
+            checkpoint_dir=cfg.checkpoint_dir,
+            checkpoint_name=checkpoint,
+            offload_network=cfg.offload_diffusion_transformer,
+            offload_text_encoder_model=cfg.offload_text_encoder_model,
+            offload_guardrail_models=cfg.offload_guardrail_models,
+            guidance=cfg.guidance,
+            num_steps=cfg.num_steps,
+            fps=cfg.fps,
+            seed=cfg.seed,
+            num_input_frames=cfg.num_input_frames,
+            control_inputs=control_inputs,
+            sigma_max=cfg.sigma_max,
+            blur_strength=cfg.blur_strength,
+            canny_threshold=cfg.canny_threshold,
+            upsample_prompt=cfg.upsample_prompt,
+            offload_prompt_upsampler=cfg.offload_prompt_upsampler,
+            process_group=process_group,
+        )
+    else:
+        checkpoint = BASE_7B_CHECKPOINT_AV_SAMPLE_PATH if cfg.is_av_sample else BASE_7B_CHECKPOINT_PATH
+
+        # Initialize transfer generation model pipeline
+        pipeline = DiffusionControl2WorldGenerationPipeline(
+            checkpoint_dir=cfg.checkpoint_dir,
+            checkpoint_name=checkpoint,
+            offload_network=cfg.offload_diffusion_transformer,
+            offload_text_encoder_model=cfg.offload_text_encoder_model,
+            offload_guardrail_models=cfg.offload_guardrail_models,
+            guidance=cfg.guidance,
+            num_steps=cfg.num_steps,
+            fps=cfg.fps,
+            seed=cfg.seed,
+            num_input_frames=cfg.num_input_frames,
+            control_inputs=control_inputs,
+            sigma_max=cfg.sigma_max,
+            blur_strength=cfg.blur_strength,
+            canny_threshold=cfg.canny_threshold,
+            upsample_prompt=cfg.upsample_prompt,
+            offload_prompt_upsampler=cfg.offload_prompt_upsampler,
+            process_group=process_group,
+            chunking=chunking,
+        )
+
+    if cfg.batch_input_path:
+        log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
+        prompts = read_prompts_from_file(cfg.batch_input_path)
+    else:
+        # Single prompt case
+        prompts = [{"prompt": cfg.prompt, "visual_input": cfg.input_video_path}]
+
+    batch_size = cfg.batch_size if hasattr(cfg, "batch_size") else 1
+    if any("upscale" in control_input for control_input in control_inputs) and batch_size > 1:
+        batch_size = 1
+        log.info("Setting batch_size=1 as upscale does not support batch generation")
+    os.makedirs(cfg.video_save_folder, exist_ok=True)
+    for batch_start in range(0, len(prompts), batch_size):
+        # Get current batch
+        batch_prompts = prompts[batch_start : batch_start + batch_size]
+        actual_batch_size = len(batch_prompts)
+        # Extract batch data
+        batch_prompt_texts = [p.get("prompt", None) for p in batch_prompts]
+        batch_video_paths = [p.get("visual_input", None) for p in batch_prompts]
+
+        batch_control_inputs = []
+        for i, input_dict in enumerate(batch_prompts):
+            current_prompt = input_dict.get("prompt", None)
+            current_video_path = input_dict.get("visual_input", None)
+
+            if cfg.batch_input_path:
+                video_save_subfolder = os.path.join(cfg.video_save_folder, f"video_{batch_start+i}")
+                os.makedirs(video_save_subfolder, exist_ok=True)
+            else:
+                video_save_subfolder = cfg.video_save_folder
+
+            current_control_inputs = copy.deepcopy(control_inputs)
+            if "control_overrides" in input_dict:
+                for hint_key, override in input_dict["control_overrides"].items():
+                    if hint_key in current_control_inputs:
+                        current_control_inputs[hint_key].update(override)
+                    else:
+                        log.warning(f"Ignoring unknown control key in override: {hint_key}")
+
+            # if control inputs are not provided, run respective preprocessor (for seg and depth)
+            log.info("running preprocessor")
+            preprocessors(
+                current_video_path,
+                current_prompt,
+                current_control_inputs,
+                video_save_subfolder,
+                cfg.regional_prompts if hasattr(cfg, "regional_prompts") else None,
+            )
+            batch_control_inputs.append(current_control_inputs)
+
+        regional_prompts = []
+        region_definitions = []
+        if hasattr(cfg, "regional_prompts") and cfg.regional_prompts:
+            log.info(f"regional_prompts: {cfg.regional_prompts}")
+            for regional_prompt in cfg.regional_prompts:
+                regional_prompts.append(regional_prompt["prompt"])
+                if "region_definitions_path" in regional_prompt:
+                    log.info(f"region_definitions_path: {regional_prompt['region_definitions_path']}")
+                    region_definition_path = regional_prompt["region_definitions_path"]
+                    if isinstance(region_definition_path, str) and region_definition_path.endswith(".json"):
+                        with open(region_definition_path, "r") as f:
+                            region_definitions_json = json.load(f)
+                        region_definitions.extend(region_definitions_json)
+                    else:
+                        region_definitions.append(region_definition_path)
+
+        if hasattr(pipeline, "regional_prompts"):
+            pipeline.regional_prompts = regional_prompts
+        if hasattr(pipeline, "region_definitions"):
+            pipeline.region_definitions = region_definitions
+
+        # Generate videos in batch
+        batch_outputs = pipeline.generate(
+            prompt=batch_prompt_texts,
+            video_path=batch_video_paths,
+            negative_prompt=cfg.negative_prompt,
+            control_inputs=batch_control_inputs,
+            save_folder=video_save_subfolder,
+            batch_size=actual_batch_size,
+        )
+        if batch_outputs is None:
+            log.critical("Guardrail blocked generation for entire batch.")
+            continue
+
+        videos, final_prompts = batch_outputs
+        for i, (video, prompt) in enumerate(zip(videos, final_prompts)):
+            if cfg.batch_input_path:
+                video_save_subfolder = os.path.join(cfg.video_save_folder, f"video_{batch_start+i}")
+                video_save_path = os.path.join(video_save_subfolder, "output.mp4")
+                prompt_save_path = os.path.join(video_save_subfolder, "prompt.txt")
+            else:
+                video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
+                prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")
+            # Save video and prompt
+            if device_rank == 0:
+                os.makedirs(os.path.dirname(video_save_path), exist_ok=True)
+                save_video(
+                    video=video,
+                    fps=cfg.fps,
+                    H=video.shape[1],
+                    W=video.shape[2],
+                    video_save_quality=5,
+                    video_save_path=video_save_path,
+                )
+                video_paths.append(video_save_path)
+
+                # Save prompt to text file alongside video
+                with open(prompt_save_path, "wb") as f:
+                    f.write(prompt.encode("utf-8"))
+
+                prompt_paths.append(prompt_save_path)
+
+                log.info(f"Saved video to {video_save_path}")
+                log.info(f"Saved prompt to {prompt_save_path}")
+
+    # clean up properly
+    if cfg.num_gpus > 1:
+        parallel_state.destroy_model_parallel()
+        import torch.distributed as dist
+
+        dist.destroy_process_group()
+
+    return video_paths, prompt_paths
+
+
+def create_zip_for_download(filename, files_to_zip):
+    temp_dir = tempfile.mkdtemp()
+    zip_path = os.path.join(temp_dir, f"{os.path.splitext(filename)[0]}.zip")
+
+    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
+        for file_path in files_to_zip:
+            arcname = os.path.basename(file_path)
+            zipf.write(file_path, arcname)
+
+    return zip_path
+
+
+import gradio as gr
+
+
+def generate_video(
+    rgb_video_path,
+    hdmap_video_input,
+    lidar_video_input,
+    prompt,
+    negative_prompt="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.",  # noqa: E501
+    seed=42,
+    randomize_seed=False,
+    chunking=None,
+    progress=gr.Progress(track_tqdm=True),
+):
+    _dt = datetime.datetime.now(tz=datetime.timezone(datetime.timedelta(hours=8))).strftime("%Y-%m-%d_%H.%M.%S")
+    logfile_path = os.path.join(LOG_DIR, f"{_dt}.log")
+    log_handler = log.init_dev_loguru_file(logfile_path)
+
+    if randomize_seed:
+        actual_seed = random.randint(0, 1000000)
+    else:
+        actual_seed = seed
+
+    log.info(f"actual_seed: {actual_seed}")
+
+    if rgb_video_path is None or not os.path.isfile(rgb_video_path):
+        log.warning(f"File `{rgb_video_path}` does not exist")
+        rgb_video_path = ""
+
+    # add timer to calculate the generation time
+    start_time = time.time()
+
+    # parse generation configs
+    args, control_inputs = parse_arguments(
+        controlnet_specs_in={
+            "hdmap": {"control_weight": 0.3, "input_control": hdmap_video_input},
+            "lidar": {"control_weight": 0.7, "input_control": lidar_video_input},
+        },
+        input_video_path=rgb_video_path,
+        checkpoint_dir=CHECKPOINTS_PATH,
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        sigma_max=80,
+        offload_text_encoder_model=True,
+        is_av_sample=True,
+        num_gpus=1,
+        seed=seed,
+    )
+
+    # watch gpu memory
+    watcher = watch_gpu_memory(10, lambda x: log.debug(f"GPU memory usage: {x} (MiB)"))
+
+    # start inference
+    if chunking <= 0:
+        chunking = None
+    videos, prompts = inference(args, control_inputs, chunking)
+
+    # print the generation time
+    end_time = time.time()
+    log.info(f"Time taken: {end_time - start_time} s")
+
+    # stop the watcher
+    stop_watcher()
+
+    video = videos[0]
+
+    log.logger.remove(log_handler)
+    return video, create_zip_for_download(filename=logfile_path, files_to_zip=[video, logfile_path]), actual_seed
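After this commit, `app.py` keeps only the Gradio UI and imports `generate_video` from `helper.py`. As a minimal sketch, the relocated function can also be driven directly, outside the UI (the input paths and prompt below are hypothetical; `chunking=0` is passed because `generate_video` treats values <= 0 as disabling chunked inference):

from helper import generate_video

# Hypothetical inputs; returns (video_path, zip_of_video_and_log, used_seed).
video_path, bundle_zip, used_seed = generate_video(
    rgb_video_path="inputs/rgb.mp4",
    hdmap_video_input="inputs/hdmap.mp4",
    lidar_video_input="inputs/lidar.mp4",
    prompt="A rainy evening drive through a city intersection.",
    seed=42,
    randomize_seed=False,
    chunking=0,  # <= 0 is normalized to None (no chunked inference)
)
print(video_path, used_seed)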