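# Gradio Space: Wan 2.2 first/last-frame-to-video via a ComfyUI backend on ZeroGPU.
# The app downloads the Wan 2.2 I2V checkpoints and Lightning LoRAs, builds the
# ComfyUI node graph once at startup, and exposes a single generate_video() function.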
import os
import shutil
import sys
import subprocess
import asyncio
import uuid
import random
import tempfile
from typing import Sequence, Mapping, Any, Union

import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
import spaces
# --- 1. Model Download and Setup ---
def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
    """Downloads a file from the Hugging Face Hub and symlinks it into a local directory."""
    downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
    os.makedirs(local_dir, exist_ok=True)
    base_filename = os.path.basename(filename)
    target_path = os.path.join(local_dir, base_filename)
    # Remove an existing symlink or file to avoid errors
    if os.path.exists(target_path) or os.path.islink(target_path):
        os.remove(target_path)
    os.symlink(downloaded_path, target_path)
    return target_path
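# hf_hub_download keeps files in the shared HF cache; the symlink exposes each file
# under models/<kind>/<filename> so ComfyUI's loaders can find it by plain filename
# without copying the multi-GB checkpoints.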
| print("Downloading models from Hugging Face Hub...") | |
| hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", local_dir="models/text_encoders") | |
| hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", local_dir="models/unet") | |
| hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", local_dir="models/unet") | |
| hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/vae/wan_2.1_vae.safetensors", local_dir="models/vae") | |
| hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/clip_vision/clip_vision_h.safetensors", local_dir="models/clip_vision") | |
| hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors", local_dir="models/loras") | |
| hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", local_dir="models/loras") | |
| print("Downloads complete.") | |
# --- 2. ComfyUI Backend Initialization ---
def find_path(name: str, path: str = None) -> str:
    """Searches upward from `path` (default: the current working directory) for a directory with the given name."""
    if path is None:
        path = os.getcwd()
    if name in os.listdir(path):
        return os.path.join(path, name)
    parent_directory = os.path.dirname(path)
    return find_path(name, parent_directory) if parent_directory != path else None
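# Example: if the Space checkout has a "ComfyUI" folder next to app.py,
# find_path("ComfyUI") returns its path; otherwise the search climbs toward the
# filesystem root and finally returns None.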
def add_comfyui_directory_to_sys_path() -> None:
    """Adds the ComfyUI directory to sys.path for imports."""
    comfyui_path = find_path("ComfyUI")
    if comfyui_path and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")
def add_extra_model_paths() -> None:
    """Initializes ComfyUI's folder_paths with custom paths."""
    from main import apply_custom_paths
    apply_custom_paths()
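# `main` here refers to ComfyUI's own main.py, which is importable only after
# add_comfyui_directory_to_sys_path() has put the ComfyUI checkout on sys.path.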
def import_custom_nodes() -> None:
    """Initializes all ComfyUI custom nodes."""
    import nodes
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(nodes.init_extra_nodes(init_custom_nodes=True))
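# init_extra_nodes() is a coroutine, hence the dedicated event loop above to run
# it to completion during startup.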
| print("Setting up ComfyUI paths and nodes...") | |
| add_comfyui_directory_to_sys_path() | |
| add_extra_model_paths() | |
| import_custom_nodes() | |
| print("ComfyUI setup complete.") | |
# --- 3. Global Model & Node Loading and Patching ---
from nodes import NODE_CLASS_MAPPINGS
import folder_paths
from comfy import model_management

# Set VRAM mode to HIGH to prevent models from being offloaded from the GPU after use.
# model_management.vram_state = model_management.VRAMState.HIGH_VRAM

MODELS_AND_NODES = {}
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Helper to safely access outputs from ComfyUI nodes, which are often tuples."""
    try:
        return obj[index]
    except (KeyError, TypeError):
        # Fallback for custom nodes that might return a dictionary with a 'result' key
        if isinstance(obj, Mapping) and "result" in obj:
            return obj["result"][index]
        raise
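# Example (illustrative): cliploader.load_clip(...) returns a one-element tuple, so
# get_value_at_index(result, 0) hands back the CLIP model object itself.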
| print("Loading models and instantiating nodes into memory. This may take a few minutes...") | |
| # Instantiate Node Classes that will be used for loading and patching | |
| cliploader = NODE_CLASS_MAPPINGS["CLIPLoader"]() | |
| unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]() | |
| vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]() | |
| clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]() | |
| loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]() | |
| modelsamplingsd3 = NODE_CLASS_MAPPINGS["ModelSamplingSD3"]() | |
| pathchsageattentionkj = NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]() | |
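# Note: "PathchSageAttentionKJ" (sic) appears to be the exact key this node is
# registered under by the ComfyUI-KJNodes pack, so the spelling is kept as-is.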
# Load base models into CPU RAM initially
MODELS_AND_NODES["clip"] = cliploader.load_clip(clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan")
unet_low_noise = unetloader.load_unet(unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
unet_high_noise = unetloader.load_unet(unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
MODELS_AND_NODES["vae"] = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")
MODELS_AND_NODES["clip_vision"] = clipvisionloader.load_clip(clip_name="clip_vision_h.safetensors")
# Chain all patching operations together for the final models
print("Applying all patches to models...")

# --- Low Noise Model Chain ---
model_low_with_lora = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors",
    strength_model=0.8, model=get_value_at_index(unet_low_noise, 0))
model_low_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_low_with_lora, 0))
MODELS_AND_NODES["model_low_noise"] = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_low_patched, 0))

# --- High Noise Model Chain ---
model_high_with_lora = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors",
    strength_model=0.8, model=get_value_at_index(unet_high_noise, 0))
model_high_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_high_with_lora, 0))
MODELS_AND_NODES["model_high_noise"] = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_high_patched, 0))
# Instantiate all other node classes ONCE and store them
MODELS_AND_NODES["CLIPTextEncode"] = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
MODELS_AND_NODES["LoadImage"] = NODE_CLASS_MAPPINGS["LoadImage"]()
MODELS_AND_NODES["CLIPVisionEncode"] = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
MODELS_AND_NODES["WanFirstLastFrameToVideo"] = NODE_CLASS_MAPPINGS["WanFirstLastFrameToVideo"]()
MODELS_AND_NODES["KSamplerAdvanced"] = NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
MODELS_AND_NODES["VAEDecode"] = NODE_CLASS_MAPPINGS["VAEDecode"]()
MODELS_AND_NODES["CreateVideo"] = NODE_CLASS_MAPPINGS["CreateVideo"]()
MODELS_AND_NODES["SaveVideo"] = NODE_CLASS_MAPPINGS["SaveVideo"]()
# Move all final, fully patched models to the GPU
print("Moving final models to GPU...")
model_loaders_final = [
    MODELS_AND_NODES["clip"],
    # MODELS_AND_NODES["vae"],
    MODELS_AND_NODES["model_low_noise"],
    MODELS_AND_NODES["model_high_noise"],
    MODELS_AND_NODES["clip_vision"],
]
model_management.load_models_gpu([
    loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders_final
], force_patch_weights=True)  # force_patch_weights bakes the LoRA weights into the models permanently
print("All models loaded, patched, and on GPU. Gradio app is ready.")
# --- 4. Application Logic and Gradio Interface ---
def calculate_video_dimensions(width, height, max_size=832, min_size=480):
    """Scales the input dimensions to the working resolution (long side capped at max_size, square inputs mapped to min_size), rounded to multiples of 16."""
    if width == height:
        return min_size, min_size
    aspect_ratio = width / height
    if width > height:
        video_width = max_size
        video_height = int(max_size / aspect_ratio)
    else:
        video_height = max_size
        video_width = int(max_size * aspect_ratio)
    video_width = max(16, round(video_width / 16) * 16)
    video_height = max(16, round(video_height / 16) * 16)
    return video_width, video_height
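# Example: a 1920x1080 start frame maps to an 832x464 video (468 rounded to the
# nearest multiple of 16).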
def resize_and_crop_to_match(target_image, reference_image):
    """Resizes and center-crops the target image to match the reference image's dimensions."""
    ref_width, ref_height = reference_image.size
    target_width, target_height = target_image.size
    scale = max(ref_width / target_width, ref_height / target_height)
    new_width, new_height = int(target_width * scale), int(target_height * scale)
    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
    return resized.crop((left, top, left + ref_width, top + ref_height))
@spaces.GPU(duration=120)  # ZeroGPU: allocate a GPU per call; 120 s is an assumed upper bound for one generation
def generate_video(
    start_image_pil,
    end_image_pil,
    prompt,
    # Default Wan negative prompt (Chinese). Rough English gloss: garish colors, overexposed,
    # static, blurry details, subtitles, style, artwork, painting, still frame, overall gray,
    # worst/low quality, JPEG artifacts, ugly, mutilated, extra fingers, poorly drawn hands/face,
    # deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background,
    # three legs, many people in the background, walking backwards.
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
    duration=33,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generates a video by interpolating between a start and end image, guided by a text prompt.
    This function relies on globally pre-loaded models and pre-instantiated ComfyUI nodes.
    """
    FPS = 16

    # --- 1. Retrieve Pre-loaded and Pre-patched Models & Node Instances ---
    # These are not re-instantiated; we are just getting references to the global objects.
    clip = MODELS_AND_NODES["clip"]
    vae = MODELS_AND_NODES["vae"]
    model_low_final = MODELS_AND_NODES["model_low_noise"]
    model_high_final = MODELS_AND_NODES["model_high_noise"]
    clip_vision = MODELS_AND_NODES["clip_vision"]

    cliptextencode = MODELS_AND_NODES["CLIPTextEncode"]
    loadimage = MODELS_AND_NODES["LoadImage"]
    clipvisionencode = MODELS_AND_NODES["CLIPVisionEncode"]
    wanfirstlastframetovideo = MODELS_AND_NODES["WanFirstLastFrameToVideo"]
    ksampleradvanced = MODELS_AND_NODES["KSamplerAdvanced"]
    vaedecode = MODELS_AND_NODES["VAEDecode"]
    createvideo = MODELS_AND_NODES["CreateVideo"]
    savevideo = MODELS_AND_NODES["SaveVideo"]
    # --- 2. Image Preprocessing for the Current Run ---
    print("Preprocessing images with Pillow...")
    processed_start_image = start_image_pil.copy()
    processed_end_image = resize_and_crop_to_match(end_image_pil, start_image_pil)
    video_width, video_height = calculate_video_dimensions(processed_start_image.width, processed_start_image.height)

    # Save processed images to temporary files for the LoadImage node
    temp_dir = "input"  # ComfyUI's default input directory
    os.makedirs(temp_dir, exist_ok=True)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=temp_dir) as start_file, \
         tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=temp_dir) as end_file:
        processed_start_image.save(start_file.name)
        processed_end_image.save(end_file.name)
        start_image_path = os.path.basename(start_file.name)
        end_image_path = os.path.basename(end_file.name)
    print(f"Images resized to {video_width}x{video_height} and saved temporarily.")
    # --- 3. Execute the ComfyUI Workflow in Inference Mode ---
    with torch.inference_mode():
        progress(0.1, desc="Encoding text and images...")

        # Encode the prompts and run the CLIP vision encoder on both frames
        positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
        negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))

        start_image_loaded = loadimage.load_image(image=start_image_path)
        end_image_loaded = loadimage.load_image(image=end_image_path)

        clip_vision_encoded_start = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(start_image_loaded, 0))
        clip_vision_encoded_end = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(end_image_loaded, 0))

        progress(0.2, desc="Preparing initial latents...")
        initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
            width=video_width, height=video_height, length=duration, batch_size=1,
            positive=get_value_at_index(positive_conditioning, 0),
            negative=get_value_at_index(negative_conditioning, 0),
            vae=get_value_at_index(vae, 0),
            clip_vision_start_image=get_value_at_index(clip_vision_encoded_start, 0),
            clip_vision_end_image=get_value_at_index(clip_vision_encoded_end, 0),
            start_image=get_value_at_index(start_image_loaded, 0),
            end_image=get_value_at_index(end_image_loaded, 0),
        )
        ksampler_positive = get_value_at_index(initial_latents, 0)
        ksampler_negative = get_value_at_index(initial_latents, 1)
        ksampler_latent = get_value_at_index(initial_latents, 2)
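        # Wan 2.2 I2V runs two experts: the high-noise model denoises steps 0-4,
        # then the low-noise model takes over from step 4 with no extra noise added.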
        progress(0.5, desc="Denoising (Step 1/2)...")
        latent_step1 = ksampleradvanced.sample(
            add_noise="enable", noise_seed=random.randint(1, 2**64 - 1), steps=8, cfg=1,
            sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
            return_with_leftover_noise="enable", model=get_value_at_index(model_high_final, 0),
            positive=ksampler_positive,
            negative=ksampler_negative,
            latent_image=ksampler_latent,
        )

        progress(0.7, desc="Denoising (Step 2/2)...")
        latent_step2 = ksampleradvanced.sample(
            add_noise="disable", noise_seed=random.randint(1, 2**64 - 1), steps=8, cfg=1,
            sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
            return_with_leftover_noise="disable", model=get_value_at_index(model_low_final, 0),
            positive=ksampler_positive,
            negative=ksampler_negative,
            latent_image=get_value_at_index(latent_step1, 0),
        )
        progress(0.8, desc="Decoding VAE...")
        decoded_images = vaedecode.decode(samples=get_value_at_index(latent_step2, 0), vae=get_value_at_index(vae, 0))

        progress(0.9, desc="Creating and saving video...")
        video_data = createvideo.create_video(fps=FPS, images=get_value_at_index(decoded_images, 0))

        # Save the video to ComfyUI's default output directory
        save_result = savevideo.save_video(
            filename_prefix="GradioVideo", format="mp4", codec="h264",
            video=get_value_at_index(video_data, 0),
        )
        progress(1.0, desc="Done!")

    # --- 4. Cleanup and Return ---
    try:
        os.remove(start_file.name)
        os.remove(end_file.name)
    except Exception as e:
        print(f"Error cleaning up temporary files: {e}")

    # Gradio's video component expects a filepath relative to the root of the app
    return f"output/{save_result['ui']['images'][0]['filename']}"
css = '''
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
'''
with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
    gr.Markdown("Running the [Wan 2.2 First/Last Frame ComfyUI workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/) with the [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 4-step LoRA (8 denoising steps total across the two experts) on ZeroGPU")
    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    start_image = gr.Image(type="pil", label="Start Frame")
                    end_image = gr.Image(type="pil", label="End Frame")
                prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")
            with gr.Accordion("Advanced Settings", open=False, visible=False):
                duration = gr.Radio(
                    [("Short (2s)", 33), ("Mid (4s)", 66)],
                    value=33,
                    label="Video Duration",
                    visible=False
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
                    visible=False
                )
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)

    generate_button.click(
        fn=generate_video,
        inputs=[start_image, end_image, prompt, negative_prompt, duration],
        outputs=output_video
    )
    gr.Examples(
        examples=[
            ["poli_tower.png", "tower_takes_off.png", "the man turns around"],
            ["ugly_sonic.jpeg", "squatting_sonic.png", "the character dodges the missiles"],
            ["capyabara_zoomed.png", "capybara.webp", "a dramatic dolly zoom"],
        ],
        inputs=[start_image, end_image, prompt],
        outputs=output_video,
        fn=generate_video,
        cache_examples="lazy",
    )
if __name__ == "__main__":
    app.launch(share=True)