import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# wan2.2-main/gradio_ti2v.py
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image
import random
import numpy as np
import spaces
import cv2
import tempfile
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.utils import cache_video
import gc
# --- 1. Global Setup and Model Loading ---
print("Starting Gradio App for Wan 2.2 TI2V-5B...")
# Download model snapshots from Hugging Face Hub
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
print(f"Downloading/loading checkpoints for {repo_id}...")
ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
print(f"Using checkpoints from {ckpt_dir}")
# Load the model configuration
TASK_NAME = 'ti2v-5B'
cfg = WAN_CONFIGS[TASK_NAME]
FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 121
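# At 24 fps, 121 frames is roughly 5 seconds of video -- the longest clip this UI will request.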
# Dimension calculation constants
MOD_VALUE = 32
DEFAULT_H_SLIDER_VALUE = 704
DEFAULT_W_SLIDER_VALUE = 1280
NEW_FORMULA_MAX_AREA = 1280.0 * 704.0
SLIDER_MIN_H, SLIDER_MAX_H = 128, 1280
SLIDER_MIN_W, SLIDER_MAX_W = 128, 1280
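# Height/width sliders are snapped to multiples of MOD_VALUE, and NEW_FORMULA_MAX_AREA
# caps the pixel area used when inferring output dimensions from an uploaded image or
# video (see _calculate_new_dimensions_wan below).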
# Instantiate the pipeline in the global scope
print("Initializing WanTI2V pipeline...")
device = "cuda" if torch.cuda.is_available() else "cpu"
device_id = 0 if torch.cuda.is_available() else -1
pipeline = wan.WanTI2V(
    config=cfg,
    checkpoint_dir=ckpt_dir,
    device_id=device_id,
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_sp=False,
    t5_cpu=False,
    init_on_cpu=False,
    convert_model_dtype=True,
)
print("Pipeline initialized and ready.")
# --- Helper Functions ---
def extract_first_frame_from_video(video_path):
    """
    Extract the first frame from a video file.

    Args:
        video_path: Path to the video file
    Returns:
        PIL Image of the first frame, or None if extraction fails
    """
    try:
        cap = cv2.VideoCapture(video_path)
        ret, frame = cap.read()
        cap.release()
        if ret:
            # Convert BGR to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            return Image.fromarray(frame_rgb)
        return None
    except Exception as e:
        print(f"Error extracting frame from video: {e}")
        return None
def get_video_dimensions(video_path):
    """
    Get the dimensions of a video file.

    Args:
        video_path: Path to the video file
    Returns:
        Tuple of (width, height) or None if extraction fails
    """
    try:
        cap = cv2.VideoCapture(video_path)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        cap.release()
        return width, height
    except Exception as e:
        print(f"Error getting video dimensions: {e}")
        return None
def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
                                  min_slider_h, max_slider_h,
                                  min_slider_w, max_slider_w,
                                  default_h, default_w):
    orig_w, orig_h = pil_image.size
    if orig_w <= 0 or orig_h <= 0:
        return default_h, default_w
    aspect_ratio = orig_h / orig_w
    calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
    calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
    calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
    calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
    new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
    new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
    return new_h, new_w
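# For reference: a 1920x1080 (16:9) upload gives aspect_ratio = 0.5625, so
# calc_h = round(sqrt(901120 * 0.5625)) = 712 and calc_w = round(sqrt(901120 / 0.5625)) = 1266,
# which snap down to 704 x 1248 after flooring to multiples of 32 and clipping to the slider bounds.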
def handle_media_upload_for_dims_wan(uploaded_media, current_h_val, current_w_val):
    """
    Handle image or video upload and calculate appropriate dimensions.

    Args:
        uploaded_media: The uploaded file (can be image or video path)
        current_h_val: Current height slider value
        current_w_val: Current width slider value
    Returns:
        Tuple of (gr.update for height, gr.update for width, first frame as numpy array or None)
    """
    if uploaded_media is None:
        return (gr.update(value=DEFAULT_H_SLIDER_VALUE),
                gr.update(value=DEFAULT_W_SLIDER_VALUE),
                None)
    try:
        pil_image = None
        # Check if it's a video file
        if isinstance(uploaded_media, str) and uploaded_media.lower().endswith(('.mp4', '.avi', '.mov', '.mkv', '.webm')):
            # Extract first frame from video
            pil_image = extract_first_frame_from_video(uploaded_media)
            if pil_image is None:
                gr.Warning("Could not extract frame from video")
                return (gr.update(value=DEFAULT_H_SLIDER_VALUE),
                        gr.update(value=DEFAULT_W_SLIDER_VALUE),
                        None)
        else:
            # Handle as image
            if hasattr(uploaded_media, 'shape'):  # numpy array
                pil_image = Image.fromarray(uploaded_media).convert("RGB")
            elif isinstance(uploaded_media, str):  # file path
                pil_image = Image.open(uploaded_media).convert("RGB")
            else:  # PIL Image
                pil_image = uploaded_media
        # Calculate dimensions
        new_h, new_w = _calculate_new_dimensions_wan(
            pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
            SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
        )
        # Convert PIL image to numpy array for display
        display_image = np.array(pil_image)
        return gr.update(value=new_h), gr.update(value=new_w), display_image
    except Exception as e:
        print(f"Error in handle_media_upload_for_dims_wan: {e}")
        gr.Warning("Error processing uploaded file")
        return (gr.update(value=DEFAULT_H_SLIDER_VALUE),
                gr.update(value=DEFAULT_W_SLIDER_VALUE),
                None)
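# On Hugging Face ZeroGPU Spaces, @spaces.GPU accepts a callable `duration` that is
# invoked with the same arguments as the decorated function to size the GPU allocation
# window; here the budget is roughly 15 seconds per sampling step.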
def get_duration(video_input,
                 image_preview,
                 prompt,
                 height,
                 width,
                 duration_seconds,
                 sampling_steps,
                 guide_scale,
                 shift,
                 seed,
                 progress):
    """Calculate dynamic GPU duration based on parameters."""
    return sampling_steps * 15
# --- 2. Gradio Inference Function ---
@spaces.GPU(duration=get_duration)
def generate_video(
    video_input,
    image_preview,
    prompt,
    height,
    width,
    duration_seconds,
    sampling_steps=38,
    guide_scale=cfg.sample_guide_scale,
    shift=cfg.sample_shift,
    seed=42,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Generate a video from a text prompt and an optional image/video using the Wan 2.2 TI2V model.

    Args:
        video_input: Optional input video file path
        image_preview: Preview image (numpy array) extracted from video or uploaded image
        prompt: Text prompt describing the desired video
        height: Target video height in pixels
        width: Target video width in pixels
        duration_seconds: Desired video duration in seconds
        sampling_steps: Number of denoising steps for video generation
        guide_scale: Guidance scale for classifier-free guidance
        shift: Sample shift parameter for the model
        seed: Random seed for reproducibility (-1 for random)
        progress: Gradio progress tracker
    Returns:
        Path to the generated video file
    """
    if seed == -1:
        seed = random.randint(0, sys.maxsize)
    # Ensure dimensions are multiples of MOD_VALUE
    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
    input_image = None
    # Process video input if provided
    if video_input is not None:
        if isinstance(video_input, str) and video_input.lower().endswith(('.mp4', '.avi', '.mov', '.mkv', '.webm')):
            input_image = extract_first_frame_from_video(video_input)
        else:
            # Fallback to image preview
            if image_preview is not None:
                input_image = Image.fromarray(image_preview).convert("RGB")
    elif image_preview is not None:
        # Use image preview if no video input
        input_image = Image.fromarray(image_preview).convert("RGB")
    # Resize image to match target dimensions if we have an input image
    if input_image is not None:
        input_image = input_image.resize((target_w, target_h))
    # Calculate number of frames based on duration
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
    # Create size string for the pipeline
    size_str = f"{target_h}*{target_w}"
    video_tensor = pipeline.generate(
        input_prompt=prompt,
        img=input_image,  # Pass None for T2V, an Image for I2V
        size=SIZE_CONFIGS.get(size_str, (target_h, target_w)),
        max_area=MAX_AREA_CONFIGS.get(size_str, target_h * target_w),
        frame_num=num_frames,
        shift=shift,
        sample_solver='unipc',
        sampling_steps=int(sampling_steps),
        guide_scale=guide_scale,
        seed=seed,
        offload_model=True
    )
    # Save the video to a temporary file
    video_path = cache_video(
        tensor=video_tensor[None],  # Add a batch dimension
        save_file=None,  # cache_video will create a temp file
        fps=cfg.sample_fps,
        normalize=True,
        value_range=(-1, 1)
    )
    del video_tensor
    gc.collect()
    return video_path
# --- 3. Gradio Interface ---
css = ".gradio-container {max-width: 1200px !important; margin: 0 auto} #output_video {height: 500px;} #image_preview {height: 400px;}"
with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
    gr.Markdown("# Wan 2.2 TI2V 5B - Video/Image to Video")
    gr.Markdown("Generate high quality videos using the **Wan 2.2 5B Text-Image-to-Video model** with support for video input. [[model]](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B), [[paper]](https://arxiv.org/abs/2503.20314)")
    with gr.Row():
        with gr.Column(scale=2):
            video_input = gr.Video(
                label="Upload Video or Image (optional - blank for text-to-video)",
                sources=["upload"],
            )
            image_preview = gr.Image(
                type="numpy",
                label="Preview (first frame will be extracted from video)",
                elem_id="image_preview",
                interactive=False
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                value="A beautiful waterfall in a lush jungle, cinematic.",
                lines=3
            )
            duration_input = gr.Slider(
                minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
                maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
                step=0.1,
                value=2.0,
                label="Duration (seconds)",
                info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps."
            )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    height_input = gr.Slider(
                        minimum=SLIDER_MIN_H,
                        maximum=SLIDER_MAX_H,
                        step=MOD_VALUE,
                        value=DEFAULT_H_SLIDER_VALUE,
                        label=f"Output Height (multiple of {MOD_VALUE})"
                    )
                    width_input = gr.Slider(
                        minimum=SLIDER_MIN_W,
                        maximum=SLIDER_MAX_W,
                        step=MOD_VALUE,
                        value=DEFAULT_W_SLIDER_VALUE,
                        label=f"Output Width (multiple of {MOD_VALUE})"
                    )
                steps_input = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=38, step=1)
                scale_input = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
                shift_input = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
                seed_input = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
        with gr.Column(scale=2):
            video_output = gr.Video(label="Generated Video", elem_id="output_video")
            run_button = gr.Button("Generate Video", variant="primary")

    # Add video/image upload handler
    video_input.upload(
        fn=handle_media_upload_for_dims_wan,
        inputs=[video_input, height_input, width_input],
        outputs=[height_input, width_input, image_preview]
    )
    video_input.clear(
        fn=lambda: (gr.update(value=DEFAULT_H_SLIDER_VALUE),
                    gr.update(value=DEFAULT_W_SLIDER_VALUE),
                    None),
        inputs=[],
        outputs=[height_input, width_input, image_preview]
    )
    example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
    # Example rows must line up with generate_video's positional parameters, so
    # image_preview (left as None here) is included alongside the other inputs.
    gr.Examples(
        examples=[
            [example_image_path, None, "The cat removes the glasses from its eyes.", 1088, 800, 1.5],
            [None, None, "A cinematic shot of a boat sailing on a calm sea at sunset.", 704, 1280, 2.0],
            [None, None, "Drone footage flying over a futuristic city with flying cars.", 704, 1280, 2.0],
        ],
        inputs=[video_input, image_preview, prompt_input, height_input, width_input, duration_input],
        outputs=video_output,
        fn=generate_video,
        cache_examples="lazy",
    )
    run_button.click(
        fn=generate_video,
        inputs=[
            video_input,
            image_preview,
            prompt_input,
            height_input,
            width_input,
            duration_input,
            steps_input,
            scale_input,
            shift_input,
            seed_input
        ],
        outputs=video_output
    )
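# Note: mcp_server=True also exposes the app's functions as MCP tools; this requires
# a Gradio build with MCP support (e.g. installed via `pip install "gradio[mcp]"`).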
if __name__ == "__main__":
    demo.launch(mcp_server=True)