EuuIia committed
Commit 74d80f7 · verified · 1 Parent(s): 829e1b9

Delete inference.py

Files changed (1)
  1. inference.py +0 -774
inference.py DELETED
@@ -1,774 +0,0 @@
- import argparse
- import os
- import random
- from datetime import datetime
- from pathlib import Path
- from diffusers.utils import logging
- from typing import Optional, List, Union
- import yaml
-
- import imageio
- import json
- import numpy as np
- import torch
- import cv2
- from safetensors import safe_open
- from PIL import Image
- from transformers import (
-     T5EncoderModel,
-     T5Tokenizer,
-     AutoModelForCausalLM,
-     AutoProcessor,
-     AutoTokenizer,
- )
- from huggingface_hub import hf_hub_download
-
- from ltx_video.models.autoencoders.causal_video_autoencoder import (
-     CausalVideoAutoencoder,
- )
- from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
- from ltx_video.models.transformers.transformer3d import Transformer3DModel
- from ltx_video.pipelines.pipeline_ltx_video import (
-     ConditioningItem,
-     LTXVideoPipeline,
-     LTXMultiScalePipeline,
- )
- from ltx_video.schedulers.rf import RectifiedFlowScheduler
- from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
- from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
- import ltx_video.pipelines.crf_compressor as crf_compressor
-
- MAX_HEIGHT = 720
- MAX_WIDTH = 1280
- MAX_NUM_FRAMES = 257
-
- logger = logging.get_logger("LTX-Video")
-
-
- def get_total_gpu_memory():
-     if torch.cuda.is_available():
-         total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
-         return total_memory
-     return 0
-
-
- def get_device():
-     if torch.cuda.is_available():
-         return "cuda"
-     elif torch.backends.mps.is_available():
-         return "mps"
-     return "cpu"
-
-
- def load_image_to_tensor_with_resize_and_crop(
-     image_input: Union[str, Image.Image],
-     target_height: int = 512,
-     target_width: int = 768,
-     just_crop: bool = False,
- ) -> torch.Tensor:
-     """Load and process an image into a tensor.
-
-     Args:
-         image_input: Either a file path (str) or a PIL Image object
-         target_height: Desired height of output tensor
-         target_width: Desired width of output tensor
-         just_crop: If True, only crop the image to the target size without resizing
-     """
-     if isinstance(image_input, str):
-         image = Image.open(image_input).convert("RGB")
-     elif isinstance(image_input, Image.Image):
-         image = image_input
-     else:
-         raise ValueError("image_input must be either a file path or a PIL Image object")
-
-     input_width, input_height = image.size
-     aspect_ratio_target = target_width / target_height
-     aspect_ratio_frame = input_width / input_height
-     if aspect_ratio_frame > aspect_ratio_target:
-         new_width = int(input_height * aspect_ratio_target)
-         new_height = input_height
-         x_start = (input_width - new_width) // 2
-         y_start = 0
-     else:
-         new_width = input_width
-         new_height = int(input_width / aspect_ratio_target)
-         x_start = 0
-         y_start = (input_height - new_height) // 2
-
-     image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
-     if not just_crop:
-         image = image.resize((target_width, target_height))
-
-     image = np.array(image)
-     image = cv2.GaussianBlur(image, (3, 3), 0)
-     frame_tensor = torch.from_numpy(image).float()
-     frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
-     frame_tensor = frame_tensor.permute(2, 0, 1)
-     frame_tensor = (frame_tensor / 127.5) - 1.0
-     # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
-     return frame_tensor.unsqueeze(0).unsqueeze(2)
-
-
- def calculate_padding(
-     source_height: int, source_width: int, target_height: int, target_width: int
- ) -> tuple[int, int, int, int]:
-
-     # Calculate total padding needed
-     pad_height = target_height - source_height
-     pad_width = target_width - source_width
-
-     # Calculate padding for each side
-     pad_top = pad_height // 2
-     pad_bottom = pad_height - pad_top  # Handles odd padding
-     pad_left = pad_width // 2
-     pad_right = pad_width - pad_left  # Handles odd padding
-
-     # Return the padding amounts
-     # Padding format is (left, right, top, bottom)
-     padding = (pad_left, pad_right, pad_top, pad_bottom)
-     return padding
-
-
- def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
-     # Remove non-letters and convert to lowercase
-     clean_text = "".join(
-         char.lower() for char in text if char.isalpha() or char.isspace()
-     )
-
-     # Split into words
-     words = clean_text.split()
-
-     # Build result string keeping track of length
-     result = []
-     current_length = 0
-
-     for word in words:
-         # Include the word only if the accumulated length stays within max_len
-         new_length = current_length + len(word)
-
-         if new_length <= max_len:
-             result.append(word)
-             current_length += len(word)
-         else:
-             break
-
-     return "-".join(result)
-
-
- # Generate output video name
- def get_unique_filename(
-     base: str,
-     ext: str,
-     prompt: str,
-     seed: int,
-     resolution: tuple[int, int, int],
-     dir: Path,
-     endswith=None,
-     index_range=1000,
- ) -> Path:
-     base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
-     for i in range(index_range):
-         filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
-         if not os.path.exists(filename):
-             return filename
-     raise FileExistsError(
-         f"Could not find a unique filename after {index_range} attempts."
-     )
-
-
- def seed_everething(seed: int):
-     random.seed(seed)
-     np.random.seed(seed)
-     torch.manual_seed(seed)
-     if torch.cuda.is_available():
-         torch.cuda.manual_seed(seed)
-     if torch.backends.mps.is_available():
-         torch.mps.manual_seed(seed)
-
-
- def main():
-     parser = argparse.ArgumentParser(
-         description="Load models from separate directories and run the pipeline."
-     )
-
-     # Directories
-     parser.add_argument(
-         "--output_path",
-         type=str,
-         default=None,
-         help="Path to the folder to save output video, if None will save in outputs/ directory.",
-     )
-     parser.add_argument("--seed", type=int, default="171198")
-
-     # Pipeline parameters
-     parser.add_argument(
-         "--num_images_per_prompt",
-         type=int,
-         default=1,
-         help="Number of images per prompt",
-     )
-     parser.add_argument(
-         "--image_cond_noise_scale",
-         type=float,
-         default=0.15,
-         help="Amount of noise to add to the conditioned image",
-     )
-     parser.add_argument(
-         "--height",
-         type=int,
-         default=704,
-         help="Height of the output video frames. Optional if an input image is provided.",
-     )
-     parser.add_argument(
-         "--width",
-         type=int,
-         default=1216,
-         help="Width of the output video frames. If None will infer from input image.",
-     )
-     parser.add_argument(
-         "--num_frames",
-         type=int,
-         default=121,
-         help="Number of frames to generate in the output video",
-     )
-     parser.add_argument(
-         "--frame_rate", type=int, default=30, help="Frame rate for the output video"
-     )
-     parser.add_argument(
-         "--device",
-         default=None,
-         help="Device to run inference on. If not specified, will automatically detect and use CUDA or MPS if available, else CPU.",
-     )
-     parser.add_argument(
-         "--pipeline_config",
-         type=str,
-         default="configs/ltxv-13b-0.9.7-dev.yaml",
-         help="The path to the config file for the pipeline, which contains the parameters for the pipeline",
-     )
-
-     # Prompts
-     parser.add_argument(
-         "--prompt",
-         type=str,
-         help="Text prompt to guide generation",
-     )
-     parser.add_argument(
-         "--negative_prompt",
-         type=str,
-         default="worst quality, inconsistent motion, blurry, jittery, distorted",
-         help="Negative prompt for undesired features",
-     )
-
-     parser.add_argument(
-         "--offload_to_cpu",
-         action="store_true",
-         help="Offloading unnecessary computations to CPU.",
-     )
-
-     # video-to-video arguments:
-     parser.add_argument(
-         "--input_media_path",
-         type=str,
-         default=None,
-         help="Path to the input video (or image) to be modified using the video-to-video pipeline",
-     )
-
-     # Conditioning arguments
-     parser.add_argument(
-         "--conditioning_media_paths",
-         type=str,
-         nargs="*",
-         help="List of paths to conditioning media (images or videos). Each path will be used as a conditioning item.",
-     )
-     parser.add_argument(
-         "--conditioning_strengths",
-         type=float,
-         nargs="*",
-         help="List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items.",
-     )
-     parser.add_argument(
-         "--conditioning_start_frames",
-         type=int,
-         nargs="*",
-         help="List of frame indices where each conditioning item should be applied. Must match the number of conditioning items.",
-     )
-
-     args = parser.parse_args()
-     logger.warning(f"Running generation with arguments: {args}")
-     infer(**vars(args))
-
-
- def create_ltx_video_pipeline(
-     ckpt_path: str,
-     precision: str,
-     text_encoder_model_name_or_path: str,
-     sampler: Optional[str] = None,
-     device: Optional[str] = None,
-     enhance_prompt: bool = False,
-     prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
-     prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
- ) -> LTXVideoPipeline:
-     ckpt_path = Path(ckpt_path)
-     assert os.path.exists(
-         ckpt_path
-     ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
-
-     with safe_open(ckpt_path, framework="pt") as f:
-         metadata = f.metadata()
-         config_str = metadata.get("config")
-         configs = json.loads(config_str)
-         allowed_inference_steps = configs.get("allowed_inference_steps", None)
-
-     vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
-     transformer = Transformer3DModel.from_pretrained(ckpt_path)
-
-     # Use constructor if sampler is specified, otherwise use from_pretrained
-     if sampler == "from_checkpoint" or not sampler:
-         scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
-     else:
-         scheduler = RectifiedFlowScheduler(
-             sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
-         )
-
-     text_encoder = T5EncoderModel.from_pretrained(
-         text_encoder_model_name_or_path, subfolder="text_encoder"
-     )
-     patchifier = SymmetricPatchifier(patch_size=1)
-     tokenizer = T5Tokenizer.from_pretrained(
-         text_encoder_model_name_or_path, subfolder="tokenizer"
-     )
-
-     transformer = transformer.to(device)
-     vae = vae.to(device)
-     text_encoder = text_encoder.to(device)
-
-     if enhance_prompt:
-         prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
-             prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
-         )
-         prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
-             prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
-         )
-         prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
-             prompt_enhancer_llm_model_name_or_path,
-             torch_dtype="bfloat16",
-         )
-         prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
-             prompt_enhancer_llm_model_name_or_path,
-         )
-     else:
-         prompt_enhancer_image_caption_model = None
-         prompt_enhancer_image_caption_processor = None
-         prompt_enhancer_llm_model = None
-         prompt_enhancer_llm_tokenizer = None
-
-     vae = vae.to(torch.bfloat16)
-     if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
-         transformer = transformer.to(torch.bfloat16)
-     text_encoder = text_encoder.to(torch.bfloat16)
-
-     # Use submodels for the pipeline
-     submodel_dict = {
-         "transformer": transformer,
-         "patchifier": patchifier,
-         "text_encoder": text_encoder,
-         "tokenizer": tokenizer,
-         "scheduler": scheduler,
-         "vae": vae,
-         "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
-         "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
-         "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
-         "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
-         "allowed_inference_steps": allowed_inference_steps,
-     }
-
-     pipeline = LTXVideoPipeline(**submodel_dict)
-     pipeline = pipeline.to(device)
-     return pipeline
-
-
- def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
-     latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
-     latent_upsampler.to(device)
-     latent_upsampler.eval()
-     return latent_upsampler
-
-
- def infer(
-     output_path: Optional[str],
-     seed: int,
-     pipeline_config: str,
-     image_cond_noise_scale: float,
-     height: Optional[int],
-     width: Optional[int],
-     num_frames: int,
-     frame_rate: int,
-     prompt: str,
-     negative_prompt: str,
-     offload_to_cpu: bool,
-     input_media_path: Optional[str] = None,
-     conditioning_media_paths: Optional[List[str]] = None,
-     conditioning_strengths: Optional[List[float]] = None,
-     conditioning_start_frames: Optional[List[int]] = None,
-     device: Optional[str] = None,
-     **kwargs,
- ):
-     # check if pipeline_config is a file
-     if not os.path.isfile(pipeline_config):
-         raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
-     with open(pipeline_config, "r") as f:
-         pipeline_config = yaml.safe_load(f)
-
-     models_dir = "MODEL_DIR"
-
-     ltxv_model_name_or_path = pipeline_config["checkpoint_path"]
-     if not os.path.isfile(ltxv_model_name_or_path):
-         ltxv_model_path = hf_hub_download(
-             repo_id="Lightricks/LTX-Video",
-             filename=ltxv_model_name_or_path,
-             local_dir=models_dir,
-             repo_type="model",
-         )
-     else:
-         ltxv_model_path = ltxv_model_name_or_path
-
-     spatial_upscaler_model_name_or_path = pipeline_config.get(
-         "spatial_upscaler_model_path"
-     )
-     if spatial_upscaler_model_name_or_path and not os.path.isfile(
-         spatial_upscaler_model_name_or_path
-     ):
-         spatial_upscaler_model_path = hf_hub_download(
-             repo_id="Lightricks/LTX-Video",
-             filename=spatial_upscaler_model_name_or_path,
-             local_dir=models_dir,
-             repo_type="model",
-         )
-     else:
-         spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
-
-     if kwargs.get("input_image_path", None):
-         logger.warning(
-             "Please use conditioning_media_paths instead of input_image_path."
-         )
-         assert not conditioning_media_paths and not conditioning_start_frames
-         conditioning_media_paths = [kwargs["input_image_path"]]
-         conditioning_start_frames = [0]
-
-     # Validate conditioning arguments
-     if conditioning_media_paths:
-         # Use default strengths of 1.0
-         if not conditioning_strengths:
-             conditioning_strengths = [1.0] * len(conditioning_media_paths)
-         if not conditioning_start_frames:
-             raise ValueError(
-                 "If `conditioning_media_paths` is provided, "
-                 "`conditioning_start_frames` must also be provided"
-             )
-         if len(conditioning_media_paths) != len(conditioning_strengths) or len(
-             conditioning_media_paths
-         ) != len(conditioning_start_frames):
-             raise ValueError(
-                 "`conditioning_media_paths`, `conditioning_strengths`, "
-                 "and `conditioning_start_frames` must have the same length"
-             )
-         if any(s < 0 or s > 1 for s in conditioning_strengths):
-             raise ValueError("All conditioning strengths must be between 0 and 1")
-         if any(f < 0 or f >= num_frames for f in conditioning_start_frames):
-             raise ValueError(
-                 f"All conditioning start frames must be between 0 and {num_frames-1}"
-             )
-
-     seed_everething(seed)
-     if offload_to_cpu and not torch.cuda.is_available():
-         logger.warning(
-             "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
-         )
-         offload_to_cpu = False
-     else:
-         offload_to_cpu = offload_to_cpu and get_total_gpu_memory() < 30
-
-     output_dir = (
-         Path(output_path)
-         if output_path
-         else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
-     )
-     output_dir.mkdir(parents=True, exist_ok=True)
-
-     # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
-     height_padded = ((height - 1) // 32 + 1) * 32
-     width_padded = ((width - 1) // 32 + 1) * 32
-     num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1
-
-     padding = calculate_padding(height, width, height_padded, width_padded)
-
-     logger.warning(
-         f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
-     )
-
-     prompt_enhancement_words_threshold = pipeline_config[
-         "prompt_enhancement_words_threshold"
-     ]
-
-     prompt_word_count = len(prompt.split())
-     enhance_prompt = (
-         prompt_enhancement_words_threshold > 0
-         and prompt_word_count < prompt_enhancement_words_threshold
-     )
-
-     if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
-         logger.info(
-             f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
-         )
-
-     precision = pipeline_config["precision"]
-     text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"]
-     sampler = pipeline_config["sampler"]
-     prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
-         "prompt_enhancer_image_caption_model_name_or_path"
-     ]
-     prompt_enhancer_llm_model_name_or_path = pipeline_config[
-         "prompt_enhancer_llm_model_name_or_path"
-     ]
-
-     pipeline = create_ltx_video_pipeline(
-         ckpt_path=ltxv_model_path,
-         precision=precision,
-         text_encoder_model_name_or_path=text_encoder_model_name_or_path,
-         sampler=sampler,
-         device=kwargs.get("device", get_device()),
-         enhance_prompt=enhance_prompt,
-         prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
-         prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
-     )
-
-     if pipeline_config.get("pipeline_type", None) == "multi-scale":
-         if not spatial_upscaler_model_path:
-             raise ValueError(
-                 "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
-             )
-         latent_upsampler = create_latent_upsampler(
-             spatial_upscaler_model_path, pipeline.device
-         )
-         pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)
-
-     media_item = None
-     if input_media_path:
-         media_item = load_media_file(
-             media_path=input_media_path,
-             height=height,
-             width=width,
-             max_frames=num_frames_padded,
-             padding=padding,
-         )
-
-     conditioning_items = (
-         prepare_conditioning(
-             conditioning_media_paths=conditioning_media_paths,
-             conditioning_strengths=conditioning_strengths,
-             conditioning_start_frames=conditioning_start_frames,
-             height=height,
-             width=width,
-             num_frames=num_frames,
-             padding=padding,
-             pipeline=pipeline,
-         )
-         if conditioning_media_paths
-         else None
-     )
-
-     stg_mode = pipeline_config.get("stg_mode", "attention_values")
-     del pipeline_config["stg_mode"]
-     if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
-         skip_layer_strategy = SkipLayerStrategy.AttentionValues
-     elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
-         skip_layer_strategy = SkipLayerStrategy.AttentionSkip
-     elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
-         skip_layer_strategy = SkipLayerStrategy.Residual
-     elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
-         skip_layer_strategy = SkipLayerStrategy.TransformerBlock
-     else:
-         raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")
-
-     # Prepare input for the pipeline
-     sample = {
-         "prompt": prompt,
-         "prompt_attention_mask": None,
-         "negative_prompt": negative_prompt,
-         "negative_prompt_attention_mask": None,
-     }
-
-     device = device or get_device()
-     generator = torch.Generator(device=device).manual_seed(seed)
-
-     images = pipeline(
-         **pipeline_config,
-         skip_layer_strategy=skip_layer_strategy,
-         generator=generator,
-         output_type="pt",
-         callback_on_step_end=None,
-         height=height_padded,
-         width=width_padded,
-         num_frames=num_frames_padded,
-         frame_rate=frame_rate,
-         **sample,
-         media_items=media_item,
-         conditioning_items=conditioning_items,
-         is_video=True,
-         vae_per_channel_normalize=True,
-         image_cond_noise_scale=image_cond_noise_scale,
-         mixed_precision=(precision == "mixed_precision"),
-         offload_to_cpu=offload_to_cpu,
-         device=device,
-         enhance_prompt=enhance_prompt,
-     ).images
-
-     # Crop the padded images to the desired resolution and number of frames
-     (pad_left, pad_right, pad_top, pad_bottom) = padding
-     pad_bottom = -pad_bottom
-     pad_right = -pad_right
-     if pad_bottom == 0:
-         pad_bottom = images.shape[3]
-     if pad_right == 0:
-         pad_right = images.shape[4]
-     images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right]
-
-     for i in range(images.shape[0]):
-         # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
-         video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
-         # Unnormalizing images to [0, 255] range
-         video_np = (video_np * 255).astype(np.uint8)
-         fps = frame_rate
-         height, width = video_np.shape[1:3]
-         # In case a single image is generated
-         if video_np.shape[0] == 1:
-             output_filename = get_unique_filename(
-                 f"image_output_{i}",
-                 ".png",
-                 prompt=prompt,
-                 seed=seed,
-                 resolution=(height, width, num_frames),
-                 dir=output_dir,
-             )
-             imageio.imwrite(output_filename, video_np[0])
-         else:
-             output_filename = get_unique_filename(
-                 f"video_output_{i}",
-                 ".mp4",
-                 prompt=prompt,
-                 seed=seed,
-                 resolution=(height, width, num_frames),
-                 dir=output_dir,
-             )
-
-             # Write video
-             with imageio.get_writer(output_filename, fps=fps) as video:
-                 for frame in video_np:
-                     video.append_data(frame)
-
-         logger.warning(f"Output saved to {output_filename}")
-
-
- def prepare_conditioning(
-     conditioning_media_paths: List[str],
-     conditioning_strengths: List[float],
-     conditioning_start_frames: List[int],
-     height: int,
-     width: int,
-     num_frames: int,
-     padding: tuple[int, int, int, int],
-     pipeline: LTXVideoPipeline,
- ) -> Optional[List[ConditioningItem]]:
-     """Prepare conditioning items based on input media paths and their parameters.
-
-     Args:
-         conditioning_media_paths: List of paths to conditioning media (images or videos)
-         conditioning_strengths: List of conditioning strengths for each media item
-         conditioning_start_frames: List of frame indices where each item should be applied
-         height: Height of the output frames
-         width: Width of the output frames
-         num_frames: Number of frames in the output video
-         padding: Padding to apply to the frames
-         pipeline: LTXVideoPipeline object used for condition video trimming
-
-     Returns:
-         A list of ConditioningItem objects.
-     """
-     conditioning_items = []
-     for path, strength, start_frame in zip(
-         conditioning_media_paths, conditioning_strengths, conditioning_start_frames
-     ):
-         num_input_frames = orig_num_input_frames = get_media_num_frames(path)
-         if hasattr(pipeline, "trim_conditioning_sequence") and callable(
-             getattr(pipeline, "trim_conditioning_sequence")
-         ):
-             num_input_frames = pipeline.trim_conditioning_sequence(
-                 start_frame, orig_num_input_frames, num_frames
-             )
-         if num_input_frames < orig_num_input_frames:
-             logger.warning(
-                 f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames."
-             )
-
-         media_tensor = load_media_file(
-             media_path=path,
-             height=height,
-             width=width,
-             max_frames=num_input_frames,
-             padding=padding,
-             just_crop=True,
-         )
-         conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
-     return conditioning_items
-
-
- def get_media_num_frames(media_path: str) -> int:
-     is_video = any(
-         media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
-     )
-     num_frames = 1
-     if is_video:
-         reader = imageio.get_reader(media_path)
-         num_frames = reader.count_frames()
-         reader.close()
-     return num_frames
-
-
- def load_media_file(
-     media_path: str,
-     height: int,
-     width: int,
-     max_frames: int,
-     padding: tuple[int, int, int, int],
-     just_crop: bool = False,
- ) -> torch.Tensor:
-     is_video = any(
-         media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
-     )
-     if is_video:
-         reader = imageio.get_reader(media_path)
-         num_input_frames = min(reader.count_frames(), max_frames)
-
-         # Read and preprocess the relevant frames from the video file.
-         frames = []
-         for i in range(num_input_frames):
-             frame = Image.fromarray(reader.get_data(i))
-             frame_tensor = load_image_to_tensor_with_resize_and_crop(
-                 frame, height, width, just_crop=just_crop
-             )
-             frame_tensor = torch.nn.functional.pad(frame_tensor, padding)
-             frames.append(frame_tensor)
-         reader.close()
-
-         # Stack frames along the temporal dimension
-         media_tensor = torch.cat(frames, dim=2)
-     else:  # Input image
-         media_tensor = load_image_to_tensor_with_resize_and_crop(
-             media_path, height, width, just_crop=just_crop
-         )
-         media_tensor = torch.nn.functional.pad(media_tensor, padding)
-     return media_tensor
-
-
- if __name__ == "__main__":
-     main()
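
For reference, a minimal sketch of how the removed script's infer() entry point could be driven programmatically, assuming inference.py and the ltx_video package it imports are still available on the Python path and the referenced YAML config exists. The keyword values simply mirror the argparse defaults visible in the diff above; the prompt string is a made-up placeholder.

# Hypothetical usage sketch (not part of the original file): call infer() directly
# with the same defaults that main()/argparse exposed. Assumes inference.py and the
# ltx_video package are importable and configs/ltxv-13b-0.9.7-dev.yaml is present.
from inference import infer

infer(
    output_path=None,  # falls back to outputs/<YYYY-MM-DD>/
    seed=171198,
    pipeline_config="configs/ltxv-13b-0.9.7-dev.yaml",
    image_cond_noise_scale=0.15,
    height=704,
    width=1216,
    num_frames=121,
    frame_rate=30,
    prompt="A placeholder prompt describing the desired clip",  # hypothetical example
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    offload_to_cpu=False,
)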