# Source: Hugging Face Space by multimodalart
# Commit 211837b (verified): "Update app.py"
# (Hub file viewer residue: raw | history | blame | 18.2 kB)
import os
import shutil
import sys
import subprocess
import asyncio
import uuid
import random
import tempfile
from typing import Sequence, Mapping, Any, Union
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
import spaces
# --- 1. Model Download and Setup ---
def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
    """Fetch *filename* from the Hub repo *repo_id* and expose it inside *local_dir*.

    The file itself lives in the Hugging Face cache (path returned by
    ``hf_hub_download``); *local_dir* only receives a symlink named after the
    file's basename. Extra keyword arguments are forwarded to
    ``hf_hub_download``.

    Returns:
        The path of the symlink created in *local_dir*.
    """
    cached_file = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
    os.makedirs(local_dir, exist_ok=True)
    link_path = os.path.join(local_dir, os.path.basename(filename))
    # islink() also catches dangling links, which exists() reports as missing;
    # either way the old entry must go before a new symlink can be created.
    if os.path.islink(link_path) or os.path.exists(link_path):
        os.remove(link_path)
    os.symlink(cached_file, link_path)
    return link_path
print("Downloading models from Hugging Face Hub...")
# (repo_id, file path inside the repo, local ComfyUI model folder to link into),
# processed in this exact order.
_MODEL_DOWNLOADS = (
    ("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", "models/text_encoders"),
    ("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", "models/unet"),
    ("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", "models/unet"),
    ("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/vae/wan_2.1_vae.safetensors", "models/vae"),
    ("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/clip_vision/clip_vision_h.safetensors", "models/clip_vision"),
    ("Kijai/WanVideo_comfy", "Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors", "models/loras"),
    ("Kijai/WanVideo_comfy", "Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", "models/loras"),
)
for _repo_id, _filename, _local_dir in _MODEL_DOWNLOADS:
    hf_hub_download_local(repo_id=_repo_id, filename=_filename, local_dir=_local_dir)
# Keep the module namespace free of the helper names used above.
del _MODEL_DOWNLOADS, _repo_id, _filename, _local_dir
print("Downloads complete.")
# --- 2. ComfyUI Backend Initialization ---
def find_path(name: str, path: Union[str, None] = None) -> Union[str, None]:
    """Look for an entry called *name* in *path* and then in each of its ancestors.

    This walks UP the directory tree (it does not descend into subdirectories).
    Any directory entry matches, files included; callers that need a
    directory must check ``os.path.isdir`` themselves.

    Args:
        name: Entry name to look for.
        path: Starting directory; defaults to the current working directory.

    Returns:
        ``os.path.join(<dir>, name)`` for the first (closest) directory that
        contains *name*, or ``None`` once the filesystem root has been checked
        without a match.
    """
    current = os.getcwd() if path is None else path
    while True:
        if name in os.listdir(current):
            return os.path.join(current, name)
        parent = os.path.dirname(current)
        # dirname() is a fixed point only at the filesystem root.
        if parent == current:
            return None
        current = parent
def add_comfyui_directory_to_sys_path() -> None:
    """Append the closest ``ComfyUI`` directory (in cwd or an ancestor) to ``sys.path``.

    If ``find_path`` finds nothing, or finds a non-directory entry, ``sys.path``
    is left untouched and nothing is printed.
    """
    located = find_path("ComfyUI")
    if not located or not os.path.isdir(located):
        return
    sys.path.append(located)
    print(f"'{located}' added to sys.path")
def add_extra_model_paths() -> None:
    """Run ComfyUI's ``main.apply_custom_paths()`` to configure its folder paths.

    The import is deferred to call time; ``main`` is presumably only importable
    once ComfyUI is reachable on ``sys.path``.
    """
    from main import apply_custom_paths as configure_comfy_paths
    configure_comfy_paths()
def import_custom_nodes() -> None:
    """Load ComfyUI's extra and custom nodes by running ``nodes.init_extra_nodes``.

    A brand-new event loop is created and installed as the current loop before
    the coroutine is driven to completion. The loop is not closed afterwards
    and remains the current event loop.
    """
    import nodes as comfy_nodes
    init_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(init_loop)
    init_loop.run_until_complete(
        comfy_nodes.init_extra_nodes(init_custom_nodes=True)
    )
# Bootstrap ComfyUI in order: make it importable via sys.path, apply its
# custom folder-path configuration, then load its extra/custom nodes.
# The ComfyUI imports in the next section (nodes, folder_paths, comfy) are
# placed after these calls so that ComfyUI is importable by then.
print("Setting up ComfyUI paths and nodes...")
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
import_custom_nodes()
print("ComfyUI setup complete.")
# --- 3. Global Model & Node Loading and Patching ---
from nodes import NODE_CLASS_MAPPINGS
import folder_paths
from comfy import model_management
# Disabled: forcing ComfyUI into HIGH_VRAM mode, which was meant to keep models
# from being offloaded from the GPU after use. ComfyUI's own VRAM state applies.
# model_management.vram_state = model_management.VRAMState.HIGH_VRAM

# Global registry of loaded models and instantiated node objects, keyed by a
# name string; populated once at import time and read by generate_video.
MODELS_AND_NODES = dict()
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, unwrapping a ``{"result": ...}`` mapping if needed.

    ComfyUI node outputs are usually tuples, so plain indexing is tried first.
    If that raises ``KeyError`` or ``TypeError`` and *obj* is a mapping with a
    ``"result"`` key, ``obj["result"][index]`` is returned instead; otherwise
    the original exception is re-raised. ``IndexError`` is never intercepted.
    """
    try:
        return obj[index]
    except (KeyError, TypeError):
        has_result_wrapper = isinstance(obj, Mapping) and "result" in obj
        if not has_result_wrapper:
            raise
        return obj["result"][index]
print("Loading models and instantiating nodes into memory. This may take a few minutes...")
# Instantiate Node Classes that will be used for loading and patching.
# Each is looked up by its registered name in ComfyUI's NODE_CLASS_MAPPINGS.
# NOTE: "PathchSageAttentionKJ" (sic) is a registry key, so its spelling must
# match the installed node's registered name exactly; do not "correct" it.
cliploader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
modelsamplingsd3 = NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
pathchsageattentionkj = NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()
# Load the base weights by filename. The names match the files linked into
# models/* by the download section; presumably ComfyUI resolves them through
# its folder_paths configuration. Nothing here explicitly targets the GPU: that
# happens in the model_management.load_models_gpu call further below.
# Later code reads element 0 of each loader result via get_value_at_index.
MODELS_AND_NODES["clip"] = cliploader.load_clip(clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan")
unet_low_noise = unetloader.load_unet(unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
unet_high_noise = unetloader.load_unet(unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
MODELS_AND_NODES["vae"] = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")
MODELS_AND_NODES["clip_vision"] = clipvisionloader.load_clip(clip_name="clip_vision_h.safetensors")
# Chain all patching operations together for the final models.
# Both noise-level UNETs go through the same three steps, each consuming
# element 0 of the previous step's output:
#   1. LoraLoaderModelOnly  - apply the matching Lightning LoRA at strength 0.8
#   2. ModelSamplingSD3     - patch with shift=8
#   3. PathchSageAttentionKJ - SageAttention patch with sage_attention="auto"
# Only the final result of each chain is stored in MODELS_AND_NODES.
print("Applying all patches to models...")
# --- Low Noise Model Chain ---
model_low_with_lora = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors",
    strength_model=0.8, model=get_value_at_index(unet_low_noise, 0))
model_low_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_low_with_lora, 0))
MODELS_AND_NODES["model_low_noise"] = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_low_patched, 0))
# --- High Noise Model Chain ---
model_high_with_lora = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors",
    strength_model=0.8, model=get_value_at_index(unet_high_noise, 0))
model_high_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_high_with_lora, 0))
MODELS_AND_NODES["model_high_noise"] = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_high_patched, 0))
# Instantiate every remaining workflow node class exactly once and register it
# under its NODE_CLASS_MAPPINGS name; generate_video fetches these instances
# from MODELS_AND_NODES by the same keys instead of re-creating them per call.
for _node_name in (
    "CLIPTextEncode",
    "LoadImage",
    "CLIPVisionEncode",
    "WanFirstLastFrameToVideo",
    "KSamplerAdvanced",
    "VAEDecode",
    "CreateVideo",
    "SaveVideo",
):
    MODELS_AND_NODES[_node_name] = NODE_CLASS_MAPPINGS[_node_name]()
# Keep the loop variable out of the module namespace.
del _node_name
# Move all final, fully-patched models to the GPU.
print("Moving final models to GPU...")
# The VAE is deliberately absent from this explicit preload list (its entry is
# commented out); generate_video still uses it from MODELS_AND_NODES.
model_loaders_final = [
    MODELS_AND_NODES["clip"],
    # MODELS_AND_NODES["vae"],
    MODELS_AND_NODES["model_low_noise"],
    MODELS_AND_NODES["model_high_noise"],
    MODELS_AND_NODES["clip_vision"],
]
# Each entry is a node-output tuple. Its element 0 is either a wrapper exposing
# a `.patcher` attribute (apparently the CLIP / CLIP-vision objects; confirm)
# or already the object load_models_gpu accepts, which is passed through as-is.
model_management.load_models_gpu([
    loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders_final
], force_patch_weights=True) # NOTE(review): presumably bakes pending patches (incl. the LoRA) into the loaded weights; confirm in comfy.model_management
print("All models loaded, patched, and on GPU. Gradio app is ready.")
# --- 4. Application Logic and Gradio Interface ---
def calculate_video_dimensions(width, height, max_size=832, min_size=480):
    """Choose the output video's (width, height) from an input image's size.

    Square inputs short-circuit to ``(min_size, min_size)``; ``min_size`` is not
    used for any other shape. Otherwise the longer side becomes ``max_size`` and
    the shorter side is scaled by the aspect ratio (truncated with ``int``).
    Both sides are then snapped to a multiple of 16 with Python's ``round``
    (half-to-even on the ``/ 16`` quotient), never going below 16.
    """
    if width == height:
        return min_size, min_size

    def snap_to_16(pixels):
        return max(16, round(pixels / 16) * 16)

    ratio = width / height
    if width > height:
        raw_w, raw_h = max_size, int(max_size / ratio)
    else:
        raw_w, raw_h = int(max_size * ratio), max_size
    return snap_to_16(raw_w), snap_to_16(raw_h)
def resize_and_crop_to_match(target_image, reference_image):
    """Resize and center-crop *target_image* to exactly *reference_image*'s size.

    The target is scaled uniformly (aspect ratio preserved) by the smallest
    factor that makes it cover the reference in both dimensions. The excess is
    then cropped evenly from both sides.

    Args:
        target_image: PIL image to transform.
        reference_image: PIL image whose ``size`` defines the output size.

    Returns:
        A new PIL image whose size equals ``reference_image.size``.
    """
    ref_width, ref_height = reference_image.size
    target_width, target_height = target_image.size
    scale = max(ref_width / target_width, ref_height / target_height)
    # Round instead of truncating. In floating point, e.g.
    # 49 * (1 / 49) == 0.9999999999999999, so int() could land one pixel short
    # of the reference. The crop box below would then reach past the image
    # edge, and Pillow fills that area with empty (zero) pixels. max() guards
    # against any remaining shortfall.
    new_width = max(ref_width, round(target_width * scale))
    new_height = max(ref_height, round(target_height * scale))
    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    left = (new_width - ref_width) // 2
    top = (new_height - ref_height) // 2
    return resized.crop((left, top, left + ref_width, top + ref_height))
def _wan_frame_count(duration):
    """Snap a requested frame count down to the nearest value of the form 4k + 1 (min 1).

    ComfyUI's Wan video nodes take lengths in steps of 4 starting at 1
    (1, 5, ..., 33, ..., 65, ...). NOTE(review): other lengths appear to lose
    their trailing frames when Wan's VAE encodes the conditioning video. For
    first/last-frame generation that would drop the end frame; confirm against
    ComfyUI's Wan VAE.
    """
    return max(1, ((int(duration) - 1) // 4) * 4 + 1)


def _remove_temp_files(paths):
    """Delete every path in *paths*; log (do not raise) any OSError."""
    for path in paths:
        try:
            os.remove(path)
        except OSError as e:
            print(f"Error cleaning up temporary files: {e}")


@spaces.GPU(duration=120)
def generate_video(
    start_image_pil,
    end_image_pil,
    prompt,
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
    duration=33,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate a video that transitions from a start frame to an end frame.

    The workflow runs on the models and node instances pre-loaded into
    ``MODELS_AND_NODES`` at import time. Sampling uses 8 Euler steps split in
    two stages: the high-noise model runs steps 0-4, then the low-noise model
    finishes from step 4.

    Args:
        start_image_pil: PIL image for the first frame; its aspect ratio sets
            the output size.
        end_image_pil: PIL image for the last frame; resized and center-cropped
            to the start image's size.
        prompt: Text describing the transition.
        negative_prompt: Text used as the negative conditioning.
        duration: Requested number of frames (not seconds), snapped down to a
            value of the form 4k + 1. The video is written at 16 fps.
        progress: Gradio progress tracker.

    Returns:
        Path of the saved MP4 inside ComfyUI's output directory.

    Raises:
        gr.Error: If either input image is missing.
    """
    FPS = 16
    # ComfyUI's largest accepted seed is 0xffffffffffffffff; random.randint is
    # inclusive on both ends, so 2**64 itself must be excluded.
    MAX_SEED = 2**64 - 1

    if start_image_pil is None or end_image_pil is None:
        raise gr.Error("Please provide both a start frame and an end frame.")
    num_frames = _wan_frame_count(duration)
    seed = random.randint(1, MAX_SEED)
    print(f"Generating {num_frames} frames with noise seed {seed}.")

    # --- 1. Retrieve Pre-loaded and Pre-patched Models & Node Instances ---
    # These are not re-instantiated; we are just getting references to the global objects.
    clip = MODELS_AND_NODES["clip"]
    vae = MODELS_AND_NODES["vae"]
    model_low_final = MODELS_AND_NODES["model_low_noise"]
    model_high_final = MODELS_AND_NODES["model_high_noise"]
    clip_vision = MODELS_AND_NODES["clip_vision"]
    cliptextencode = MODELS_AND_NODES["CLIPTextEncode"]
    loadimage = MODELS_AND_NODES["LoadImage"]
    clipvisionencode = MODELS_AND_NODES["CLIPVisionEncode"]
    wanfirstlastframetovideo = MODELS_AND_NODES["WanFirstLastFrameToVideo"]
    ksampleradvanced = MODELS_AND_NODES["KSamplerAdvanced"]
    vaedecode = MODELS_AND_NODES["VAEDecode"]
    createvideo = MODELS_AND_NODES["CreateVideo"]
    savevideo = MODELS_AND_NODES["SaveVideo"]

    # --- 2. Image Preprocessing for the Current Run ---
    print("Preprocessing images with Pillow...")
    processed_start_image = start_image_pil.copy()
    processed_end_image = resize_and_crop_to_match(end_image_pil, start_image_pil)
    video_width, video_height = calculate_video_dimensions(processed_start_image.width, processed_start_image.height)

    # LoadImage receives bare file names, which it looks up in ComfyUI's input
    # directory. Ask folder_paths for that directory instead of assuming it is
    # ./input relative to the current working directory.
    input_dir = folder_paths.get_input_directory()
    os.makedirs(input_dir, exist_ok=True)
    temp_paths = []
    try:
        for image in (processed_start_image, processed_end_image):
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=input_dir) as handle:
                # Record the path before writing so cleanup sees it even if save() fails.
                temp_paths.append(handle.name)
                image.save(handle, format="PNG")
        start_image_path, end_image_path = (os.path.basename(path) for path in temp_paths)
        print(f"Images resized to {video_width}x{video_height} and saved temporarily.")

        # --- 3. Execute the ComfyUI Workflow in Inference Mode ---
        with torch.inference_mode():
            progress(0.1, desc="Encoding text and images...")
            # Encode prompts and vision models
            positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
            negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))
            start_image_loaded = loadimage.load_image(image=start_image_path)
            end_image_loaded = loadimage.load_image(image=end_image_path)
            clip_vision_encoded_start = clipvisionencode.encode(
                crop="none", clip_vision=get_value_at_index(clip_vision, 0),
                image=get_value_at_index(start_image_loaded, 0))
            clip_vision_encoded_end = clipvisionencode.encode(
                crop="none", clip_vision=get_value_at_index(clip_vision, 0),
                image=get_value_at_index(end_image_loaded, 0))

            progress(0.2, desc="Preparing initial latents...")
            initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
                width=video_width, height=video_height, length=num_frames, batch_size=1,
                positive=get_value_at_index(positive_conditioning, 0),
                negative=get_value_at_index(negative_conditioning, 0),
                vae=get_value_at_index(vae, 0),
                clip_vision_start_image=get_value_at_index(clip_vision_encoded_start, 0),
                clip_vision_end_image=get_value_at_index(clip_vision_encoded_end, 0),
                start_image=get_value_at_index(start_image_loaded, 0),
                end_image=get_value_at_index(end_image_loaded, 0),
            )
            ksampler_positive = get_value_at_index(initial_latents, 0)
            ksampler_negative = get_value_at_index(initial_latents, 1)
            ksampler_latent = get_value_at_index(initial_latents, 2)

            progress(0.5, desc="Denoising (Step 1/2)...")
            # Stage 1: high-noise model, steps 0-4 of 8, keeping leftover noise for stage 2.
            latent_step1 = ksampleradvanced.sample(
                add_noise="enable", noise_seed=seed, steps=8, cfg=1,
                sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
                return_with_leftover_noise="enable", model=get_value_at_index(model_high_final, 0),
                positive=ksampler_positive,
                negative=ksampler_negative,
                latent_image=ksampler_latent,
            )
            progress(0.7, desc="Denoising (Step 2/2)...")
            # Stage 2: low-noise model from step 4 (end_at_step=10000 is past the
            # 8-step schedule). add_noise is disabled here, so reusing the stage-1
            # seed keeps a run reproducible from the single logged value.
            latent_step2 = ksampleradvanced.sample(
                add_noise="disable", noise_seed=seed, steps=8, cfg=1,
                sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
                return_with_leftover_noise="disable", model=get_value_at_index(model_low_final, 0),
                positive=ksampler_positive,
                negative=ksampler_negative,
                latent_image=get_value_at_index(latent_step1, 0),
            )

            progress(0.8, desc="Decoding VAE...")
            decoded_images = vaedecode.decode(samples=get_value_at_index(latent_step2, 0), vae=get_value_at_index(vae, 0))

            progress(0.9, desc="Creating and saving video...")
            video_data = createvideo.create_video(fps=FPS, images=get_value_at_index(decoded_images, 0))
            # Save the video to ComfyUI's output directory
            save_result = savevideo.save_video(
                filename_prefix="GradioVideo", format="mp4", codec="h264",
                video=get_value_at_index(video_data, 0),
            )
        progress(1.0, desc="Done!")
    finally:
        # --- 4. Cleanup: always remove the temporary inputs, even on failure ---
        _remove_temp_files(temp_paths)

    # SaveVideo reports the written file name (plus an optional subfolder)
    # relative to ComfyUI's output directory; build the full path from there.
    saved = save_result["ui"]["images"][0]
    return os.path.join(folder_paths.get_output_directory(), saved.get("subfolder", ""), saved["filename"])
css = '''
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
'''
# Gradio UI. The duration choices are frame counts (not seconds); generate_video
# writes them at 16 fps.
with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
    gr.Markdown("Running the [Wan 2.2 First/Last Frame ComfyUI workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/) and the [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 8-step LoRA on ZeroGPU")
    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    start_image = gr.Image(type="pil", label="Start Frame")
                    end_image = gr.Image(type="pil", label="End Frame")
                prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")
            with gr.Accordion("Advanced Settings", open=False, visible=False):
                # Wan lengths must be of the form 4k + 1 (ComfyUI's length input
                # steps by 4 from 1). 66 is not a valid length and would
                # presumably lose the trailing (end) frame, so "Mid" uses 65.
                duration = gr.Radio(
                    [("Short (2s)", 33), ("Mid (4s)", 65)],
                    value=33,
                    label="Video Duration",
                    visible=False
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
                    visible=False
                )
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)
    generate_button.click(
        fn=generate_video,
        inputs=[start_image, end_image, prompt, negative_prompt, duration],
        outputs=output_video
    )
    # Examples supply only the first three inputs; generate_video's defaults
    # cover negative_prompt and duration.
    gr.Examples(
        examples=[
            ["poli_tower.png", "tower_takes_off.png", "the man turns around"],
            ["ugly_sonic.jpeg", "squatting_sonic.png", "the character dodges the missiles"],
            ["capyabara_zoomed.png", "capybara.webp", "a dramatic dolly zoom"],
        ],
        inputs=[start_image, end_image, prompt],
        outputs=output_video,
        fn=generate_video,
        cache_examples="lazy",
    )
if __name__ == "__main__":
    app.launch(share=True)