RollingForcing / app.py
kunhaokhliu
update
c259bb1
import spaces
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
import os
import argparse
import time
from typing import Optional
import torch
from torchvision.io import write_video
from omegaconf import OmegaConf
from einops import rearrange
import gradio as gr
from pipeline import CausalInferencePipeline
from huggingface_hub import snapshot_download, hf_hub_download
# -----------------------------
# Globals (loaded once per process)
# -----------------------------
_PIPELINE: Optional[torch.nn.Module] = None
_DEVICE: Optional[torch.device] = None
def _ensure_gpu():
if not torch.cuda.is_available():
raise gr.Error("CUDA GPU is required to run this demo. Please run on a machine with an NVIDIA GPU.")
# Bind to GPU:0 by default
torch.cuda.set_device(0)
def _load_pipeline(config_path: str, checkpoint_path: Optional[str], use_ema: bool) -> torch.nn.Module:
global _PIPELINE, _DEVICE
if _PIPELINE is not None:
return _PIPELINE
_ensure_gpu()
_DEVICE = torch.device("cuda:0")
# Load and merge configs
config = OmegaConf.load(config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)
# Choose pipeline type based on config
pipeline = CausalInferencePipeline(config, device=_DEVICE)
# Load checkpoint if provided
if checkpoint_path and os.path.exists(checkpoint_path):
state_dict = torch.load(checkpoint_path, map_location="cpu")
if use_ema and 'generator_ema' in state_dict:
state_dict_to_load = state_dict['generator_ema']
# Remove possible FSDP prefix
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict_to_load.items():
new_state_dict[k.replace("_fsdp_wrapped_module.", "")] = v
state_dict_to_load = new_state_dict
else:
state_dict_to_load = state_dict.get('generator', state_dict)
pipeline.generator.load_state_dict(state_dict_to_load, strict=False)
# The codebase assumes bfloat16 on GPU
pipeline = pipeline.to(device=_DEVICE, dtype=torch.bfloat16)
pipeline.eval()
# Quick sanity path check for Wan models to give friendly errors
wan_dir = os.path.join('wan_models', 'Wan2.1-T2V-1.3B')
if not os.path.isdir(wan_dir):
raise gr.Error(
"Wan2.1-T2V-1.3B not found at 'wan_models/Wan2.1-T2V-1.3B'.\n"
"Please download it first, e.g.:\n"
"huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir wan_models/Wan2.1-T2V-1.3B"
)
_PIPELINE = pipeline
return _PIPELINE
def build_predict(config_path: str, checkpoint_path: Optional[str], output_dir: str, use_ema: bool):
os.makedirs(output_dir, exist_ok=True)
@spaces.GPU
def predict(prompt: str, num_frames: int) -> str:
if not prompt or not prompt.strip():
raise gr.Error("Please enter a non-empty text prompt.")
num_frames = int(num_frames)
if num_frames % 3 != 0 or not (21 <= num_frames <= 252):
raise gr.Error("Number of frames must be a multiple of 3 between 21 and 252.")
pipeline = _load_pipeline(config_path, checkpoint_path, use_ema)
# Prepare inputs
prompts = [prompt.strip()]
noise = torch.randn([1, num_frames, 16, 60, 104], device=_DEVICE, dtype=torch.bfloat16)
torch.set_grad_enabled(False)
with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
video = pipeline.inference_rolling_forcing(
noise=noise,
text_prompts=prompts,
return_latents=False,
initial_latent=None,
)
# video: [B=1, T, C, H, W] in [0,1]
video = rearrange(video, 'b t c h w -> b t h w c')[0]
video_uint8 = (video * 255.0).clamp(0, 255).to(torch.uint8).cpu()
# Save to a unique filepath
safe_stub = prompt[:60].replace(' ', '_').replace('/', '_')
ts = int(time.time())
filepath = os.path.join(output_dir, f"{safe_stub or 'video'}_{ts}.mp4")
write_video(filepath, video_uint8, fps=16)
print(f"Saved generated video to {filepath}")
return filepath
return predict
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config_path', type=str, default='configs/rolling_forcing_dmd.yaml',
help='Path to the model config')
parser.add_argument('--checkpoint_path', type=str, default='checkpoints/rolling_forcing_dmd.pt',
help='Path to rolling forcing checkpoint (.pt). If missing, will run with base weights only if available.')
parser.add_argument('--output_dir', type=str, default='videos/gradio', help='Where to save generated videos')
parser.add_argument('--no_ema', action='store_true', help='Disable EMA weights when loading checkpoint')
args = parser.parse_args()
# Download checkpoint from HuggingFace if not present
# 1️⃣ Equivalent to:
# huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
wan_model_dir = snapshot_download(
repo_id="Wan-AI/Wan2.1-T2V-1.3B",
local_dir="wan_models/Wan2.1-T2V-1.3B",
local_dir_use_symlinks=False, # same as --local-dir-use-symlinks False
)
print("Wan model downloaded to:", wan_model_dir)
# 2️⃣ Equivalent to:
# huggingface-cli download TencentARC/RollingForcing checkpoints/rolling_forcing_dmd.pt --local-dir .
rolling_ckpt_path = hf_hub_download(
repo_id="TencentARC/RollingForcing",
filename="checkpoints/rolling_forcing_dmd.pt",
local_dir=".", # where to store it
local_dir_use_symlinks=False,
)
print("RollingForcing checkpoint downloaded to:", rolling_ckpt_path)
predict = build_predict(
config_path=args.config_path,
checkpoint_path=args.checkpoint_path,
output_dir=args.output_dir,
use_ema=not args.no_ema,
)
demo = gr.Interface(
fn=predict,
inputs=[
gr.Textbox(label="Text Prompt", lines=2, placeholder="A cinematic shot of a girl dancing in the sunset."),
gr.Slider(label="Number of Latent Frames", minimum=21, maximum=252, step=3, value=21),
],
outputs=gr.Video(label="Generated Video", format="mp4"),
title="Rolling Forcing: Autoregressive Long Video Diffusion in Real Time",
description=(
"Enter a prompt and generate a video using the Rolling Forcing pipeline.\n"
"**Note:** although Rolling Forcing generates videos autoregressivelty, current Gradio demo does not support streaming outputs, so the entire video will be generated before it is displayed.\n"
"\n"
"If you find this demo useful, please consider giving it a ⭐ star on [GitHub](https://github.com/TencentARC/RollingForcing)--your support is crucial for sustaining this open-source project. "
"You can also dive deeper by reading the [paper](https://arxiv.org/abs/2509.25161) or exploring the [project page](https://kunhao-liu.github.io/Rolling_Forcing_Webpage) for more details."
),
allow_flagging='never',
)
try:
# Gradio <= 3.x
demo.queue(concurrency_count=1, max_size=2)
except TypeError:
# Gradio >= 4.x
demo.queue(max_size=2)
demo.launch(show_error=True)
if __name__ == "__main__":
main()