import os
import random
import shutil
import sys
import tempfile
import time
from typing import Sequence, Mapping, Any, Union

import spaces
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from comfy import model_management


def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
    """Download a file from the Hub and symlink it into local_dir under its base name."""
    downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
    os.makedirs(local_dir, exist_ok=True)
    base_filename = os.path.basename(filename)
    target_path = os.path.join(local_dir, base_filename)

    # Replace any stale file or symlink left over from a previous run.
    if os.path.exists(target_path) or os.path.islink(target_path):
        os.remove(target_path)

    os.symlink(downloaded_path, target_path)
    return target_path
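
# For example, downloading "split_files/vae/wan_2.1_vae.safetensors" with local_dir="models/vae"
# keeps the real file in the Hugging Face cache and creates the symlink
# models/vae/wan_2.1_vae.safetensors pointing at it, which is where ComfyUI looks for VAEs.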


print("Downloading models from Hugging Face Hub...")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", local_dir="models/text_encoders")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/vae/wan_2.1_vae.safetensors", local_dir="models/vae")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/clip_vision/clip_vision_h.safetensors", local_dir="models/clip_vision")
hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors", local_dir="models/loras")
hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", local_dir="models/loras")
print("Downloads complete.")


def calculate_video_dimensions(width, height, max_size=832, min_size=480):
    """
    Calculate video dimensions from the input image size.

    The larger dimension becomes max_size and the smaller one is scaled proportionally;
    square inputs use min_size x min_size. Results are rounded to the nearest multiple of 16.
    """
    if width == height:
        video_width = min_size
        video_height = min_size
    else:
        aspect_ratio = width / height

        if width > height:
            video_width = max_size
            video_height = int(max_size / aspect_ratio)
        else:
            video_height = max_size
            video_width = int(max_size * aspect_ratio)

    # Snap to multiples of 16, but never below 16.
    video_width = round(video_width / 16) * 16
    video_height = round(video_height / 16) * 16
    video_width = max(video_width, 16)
    video_height = max(video_height, 16)

    return video_width, video_height
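
# Worked examples (following the rules above):
#   1920 x 1080 -> 832 x 464   (long side pinned to 832; 468 rounds down to 464)
#   1080 x 1920 -> 464 x 832
#   1024 x 1024 -> 480 x 480   (square input falls back to min_size)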


def resize_and_crop_to_match(target_image, reference_image):
    """
    Resize and center-crop target_image to match reference_image's dimensions.
    """
    ref_width, ref_height = reference_image.size
    target_width, target_height = target_image.size

    # Scale so the target covers the reference in both dimensions, then crop the overflow.
    scale = max(ref_width / target_width, ref_height / target_height)
    new_width = int(target_width * scale)
    new_height = int(target_height * scale)
    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    left = (new_width - ref_width) // 2
    top = (new_height - ref_height) // 2
    right = left + ref_width
    bottom = top + ref_height

    cropped = resized.crop((left, top, right, bottom))
    return cropped
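
# Example: matching a 512x512 image to an 832x464 reference scales it by
# max(832/512, 464/512) = 1.625 to 832x832, then center-crops to 832x464.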


def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Returns the value at the given index of a sequence or mapping.

    If the object is a sequence (like a list or string), returns the value at the given index.
    If indexing raises a KeyError (as it does for most mappings), the value is looked up
    under the mapping's "result" key instead, since some ComfyUI nodes return a dictionary
    that stores its outputs under "result".

    Args:
        obj (Union[Sequence, Mapping]): The object to retrieve the value from.
        index (int): The index of the value to retrieve.

    Returns:
        Any: The value at the given index.

    Raises:
        IndexError: If the index is out of bounds and the object is not a mapping.
        KeyError: If the object is a mapping without a "result" key.
    """
    try:
        return obj[index]
    except KeyError:
        if isinstance(obj, Mapping) and "result" in obj:
            return obj["result"][index]
        raise
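
# ComfyUI node methods return tuples (e.g. UNETLoader.load_unet -> (MODEL,)), so
# get_value_at_index(output, 0) is the usual way to unwrap the single value; the few nodes
# that return a dict with a "result" key are handled by the except branch above.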


def find_path(name: str, path: str = None) -> str:
    """
    Recursively looks at parent folders starting from the given path until it finds the given name.
    Returns the path as a string if found, or None otherwise.
    """
    if path is None:
        path = os.getcwd()

    # If the name is in the current directory, return its path.
    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"'{name}' found: {path_name}")
        return path_name

    # Stop at the filesystem root, where dirname(path) == path.
    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        return None

    return find_path(name, parent_directory)


def add_comfyui_directory_to_sys_path() -> None:
    """
    Add the 'ComfyUI' directory to sys.path so its modules (nodes, folder_paths, ...) can be imported.
    """
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")
    else:
        print("Could not find the ComfyUI directory. Run this script from a directory that contains ComfyUI (or from one of its subdirectories).")


def add_extra_model_paths() -> None:
    """
    Parse the optional extra_model_paths.yaml file and register the configured model search paths.
    """
    try:
        from main import load_extra_path_config
    except ImportError:
        # Newer ComfyUI versions moved this helper out of main.py.
        try:
            from utils.extra_config import load_extra_path_config
        except ImportError:
            print(
                "Could not import load_extra_path_config. This might be okay if you don't use extra_model_paths.yaml."
            )
            return

    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths is not None:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find an optional 'extra_model_paths.yaml' config file.")


def import_custom_nodes() -> None:
    """Find all custom nodes in the custom_nodes folder and add them to NODE_CLASS_MAPPINGS.

    This sets up a new asyncio event loop, initializes the PromptServer, creates a
    PromptQueue, and then initializes the custom nodes.
    """
    import asyncio
    import execution
    import server
    from nodes import init_extra_nodes

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    loop.run_until_complete(init_extra_nodes(init_custom_nodes=True))
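
# Custom node packs frequently assume a running server (server.PromptServer.instance) and an
# execution queue at import time, so a minimal headless server and queue are created here even
# though this script never serves HTTP.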


MODELS_AND_NODES = {}

print("Setting up ComfyUI paths...")
add_comfyui_directory_to_sys_path()
add_extra_model_paths()

print("Importing custom nodes...")
import_custom_nodes()

# These imports only work after ComfyUI has been added to sys.path above.
from nodes import NODE_CLASS_MAPPINGS
import folder_paths

print("Loading models into memory. This may take a few minutes...")

# UMT5-XXL text encoder, initially placed on the CPU.
cliploader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
MODELS_AND_NODES["clip"] = cliploader.load_clip(
    clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan", device="cpu"
)

# Wan 2.2 I2V (A14B) ships two diffusion models: a high-noise "expert" used for the early
# denoising steps and a low-noise "expert" used for the later steps.
unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
unet_low_noise = unetloader.load_unet(
    unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
    weight_dtype="default",
)
unet_high_noise = unetloader.load_unet(
    unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
    weight_dtype="default",
)

vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
MODELS_AND_NODES["vae"] = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")

loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
MODELS_AND_NODES["model_low_noise"] = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors",
    strength_model=0.8,
    model=get_value_at_index(unet_low_noise, 0),
)
MODELS_AND_NODES["model_high_noise"] = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors",
    strength_model=0.8,
    model=get_value_at_index(unet_high_noise, 0),
)
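
# Per their filenames, these are the 4-step "Lightning" distillation LoRAs for the two Wan 2.2
# experts, which is what makes the very small step budget assumed in get_duration() viable.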

clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
MODELS_AND_NODES["clip_vision"] = clipvisionloader.load_clip(
    clip_name="clip_vision_h.safetensors"
)

# Node instances created once at startup and reused for every request.
MODELS_AND_NODES["CLIPTextEncode"] = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
MODELS_AND_NODES["LoadImage"] = NODE_CLASS_MAPPINGS["LoadImage"]()
MODELS_AND_NODES["CLIPVisionEncode"] = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
MODELS_AND_NODES["ModelSamplingSD3"] = NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
# "Pathch" is the node's registered name in ComfyUI-KJNodes, not a typo introduced here.
MODELS_AND_NODES["PathchSageAttentionKJ"] = NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()
MODELS_AND_NODES["WanFirstLastFrameToVideo"] = NODE_CLASS_MAPPINGS["WanFirstLastFrameToVideo"]()
MODELS_AND_NODES["KSamplerAdvanced"] = NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
MODELS_AND_NODES["VAEDecode"] = NODE_CLASS_MAPPINGS["VAEDecode"]()
MODELS_AND_NODES["CreateVideo"] = NODE_CLASS_MAPPINGS["CreateVideo"]()
MODELS_AND_NODES["SaveVideo"] = NODE_CLASS_MAPPINGS["SaveVideo"]()

print("Pre-loading main models onto GPU...")
model_loaders = [
    MODELS_AND_NODES["clip"],
    MODELS_AND_NODES["vae"],
    MODELS_AND_NODES["model_low_noise"],
    MODELS_AND_NODES["model_high_noise"],
    MODELS_AND_NODES["clip_vision"],
]
# Each loader output is a one-element tuple; the CLIP/VAE wrappers expose their underlying
# ModelPatcher via .patcher, which is what load_models_gpu expects.
model_management.load_models_gpu([
    loader[0].patcher if hasattr(loader[0], "patcher") else loader[0] for loader in model_loaders
])
print("All models loaded successfully!")


def get_duration(
    start_image_pil,
    end_image_pil,
    prompt,
    negative_prompt,
    duration_seconds,
    progress,
):
    # Rough GPU-time budget for the ZeroGPU decorator below; the Lightning LoRAs are 4-step
    # distillations, so assume 4 sampling steps at roughly 15 seconds each, capped at 300 s.
    steps = 4
    calc_time = steps * 15
    print(f"[GPU Duration Estimate] {calc_time} sec for {steps} steps")
    return min(calc_time, 300)
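
# spaces.GPU accepts a callable for `duration`; it is called with the same arguments as the
# wrapped function so the ZeroGPU allocation can be sized per request.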


@spaces.GPU(duration=get_duration)
def generate_video(
    start_image_pil,
    end_image_pil,
    prompt,
    # Default Wan negative prompt (kept in Chinese, as commonly used with the model).
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
    duration_seconds=2,  # matches the slider default below
    progress=gr.Progress(track_tqdm=True),
):
    """
    The main function to generate a video based on user inputs.
    Called every time the user clicks the 'Generate' button.
    """
    start_time = time.time()
    FPS = 16
    duration = int(FPS * duration_seconds)  # total number of output frames
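
    # NOTE: the ComfyUI workflow that produces `save_result` is omitted from this listing.
    # Judging from the nodes pre-loaded above, it presumably encodes the prompts with
    # CLIPTextEncode, encodes the start/end frames with CLIPVisionEncode, builds the
    # conditioning and latents with WanFirstLastFrameToVideo (sized via
    # calculate_video_dimensions and resize_and_crop_to_match), samples in two passes with
    # KSamplerAdvanced (high-noise expert first, then low-noise), decodes with VAEDecode,
    # and writes the frames out with CreateVideo + SaveVideo, whose output is bound to
    # `save_result`.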

    elapsed = time.time() - start_time
    print(f"[GPU Time Log] Video generated in {elapsed:.2f} sec")

    return f"output/{save_result['ui']['images'][0]['filename']}"


css = '''
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
'''

with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
    gr.Markdown("GPU time is dynamically calculated. Max video duration: **5 seconds**.")

    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    start_image = gr.Image(type="pil", label="Start Frame")
                    end_image = gr.Image(type="pil", label="End Frame")

                prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")

                duration_seconds = gr.Slider(
                    minimum=1, maximum=5, value=2, step=1,
                    label="Video Duration (seconds)"
                )

            with gr.Accordion("Advanced Settings", open=False, visible=False):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
                    visible=False
                )

            generate_button = gr.Button("Generate Video", variant="primary")

        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)

    generate_button.click(
        fn=generate_video,
        inputs=[start_image, end_image, prompt, negative_prompt, duration_seconds],
        outputs=output_video
    )

    gr.Examples(
        examples=[
            ["poli_tower.png", "tower_takes_off.png", "the man turns around"],
            ["ugly_sonic.jpeg", "squatting_sonic.png", "the character dodges the missiles"],
            ["capyabara_zoomed.png", "capybara.webp", "a dramatic dolly zoom"],
        ],
        inputs=[start_image, end_image, prompt],
        outputs=output_video,
        fn=generate_video,
        cache_examples="lazy",
    )


if __name__ == "__main__":
    app.launch(share=True)