Spaces: Running on Zero
import spaces
import subprocess
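# Install FlashAttention at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells the
# flash-attn build to skip compiling the CUDA extension locally (a common workaround
# on Spaces, where a CUDA build toolchain is assumed to be unavailable).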
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
import os
import argparse
import time
from typing import Optional
import torch
from torchvision.io import write_video
from omegaconf import OmegaConf
from einops import rearrange
import gradio as gr
from pipeline import CausalInferencePipeline
from huggingface_hub import snapshot_download, hf_hub_download
# -----------------------------
# Globals (loaded once per process)
# -----------------------------
_PIPELINE: Optional[torch.nn.Module] = None
_DEVICE: Optional[torch.device] = None
def _ensure_gpu():
if not torch.cuda.is_available():
raise gr.Error("CUDA GPU is required to run this demo. Please run on a machine with an NVIDIA GPU.")
# Bind to GPU:0 by default
torch.cuda.set_device(0)
def _load_pipeline(config_path: str, checkpoint_path: Optional[str], use_ema: bool) -> torch.nn.Module:
global _PIPELINE, _DEVICE
if _PIPELINE is not None:
return _PIPELINE
_ensure_gpu()
_DEVICE = torch.device("cuda:0")
# Load and merge configs
config = OmegaConf.load(config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)
    # Build the causal inference pipeline from the merged config
pipeline = CausalInferencePipeline(config, device=_DEVICE)
# Load checkpoint if provided
if checkpoint_path and os.path.exists(checkpoint_path):
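        # Load checkpoint weights on CPU first to avoid a GPU memory spike; the whole
        # pipeline is moved to the GPU (in bfloat16) further below.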
state_dict = torch.load(checkpoint_path, map_location="cpu")
if use_ema and 'generator_ema' in state_dict:
state_dict_to_load = state_dict['generator_ema']
# Remove possible FSDP prefix
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict_to_load.items():
new_state_dict[k.replace("_fsdp_wrapped_module.", "")] = v
state_dict_to_load = new_state_dict
else:
state_dict_to_load = state_dict.get('generator', state_dict)
pipeline.generator.load_state_dict(state_dict_to_load, strict=False)
# The codebase assumes bfloat16 on GPU
pipeline = pipeline.to(device=_DEVICE, dtype=torch.bfloat16)
pipeline.eval()
# Quick sanity path check for Wan models to give friendly errors
wan_dir = os.path.join('wan_models', 'Wan2.1-T2V-1.3B')
if not os.path.isdir(wan_dir):
raise gr.Error(
"Wan2.1-T2V-1.3B not found at 'wan_models/Wan2.1-T2V-1.3B'.\n"
"Please download it first, e.g.:\n"
"huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir wan_models/Wan2.1-T2V-1.3B"
)
_PIPELINE = pipeline
return _PIPELINE
def build_predict(config_path: str, checkpoint_path: Optional[str], output_dir: str, use_ema: bool):
os.makedirs(output_dir, exist_ok=True)
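    # @spaces.GPU requests a ZeroGPU device only for the duration of each call,
    # so all CUDA work must happen inside the decorated function.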
@spaces.GPU
def predict(prompt: str, num_frames: int) -> str:
if not prompt or not prompt.strip():
raise gr.Error("Please enter a non-empty text prompt.")
num_frames = int(num_frames)
if num_frames % 3 != 0 or not (21 <= num_frames <= 252):
raise gr.Error("Number of frames must be a multiple of 3 between 21 and 252.")
pipeline = _load_pipeline(config_path, checkpoint_path, use_ema)
# Prepare inputs
prompts = [prompt.strip()]
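        # Noise shape is [batch, latent frames, channels, height, width]; 16x60x104
        # corresponds to the Wan2.1 VAE latent space (16 channels, 8x spatial
        # downsampling of a 480x832 video).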
noise = torch.randn([1, num_frames, 16, 60, 104], device=_DEVICE, dtype=torch.bfloat16)
torch.set_grad_enabled(False)
        with torch.inference_mode(), torch.autocast('cuda', dtype=torch.bfloat16):
video = pipeline.inference_rolling_forcing(
noise=noise,
text_prompts=prompts,
return_latents=False,
initial_latent=None,
)
# video: [B=1, T, C, H, W] in [0,1]
video = rearrange(video, 'b t c h w -> b t h w c')[0]
video_uint8 = (video * 255.0).clamp(0, 255).to(torch.uint8).cpu()
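        # torchvision's write_video expects a [T, H, W, C] uint8 tensor on CPU,
        # hence the rearrange and cast above.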
# Save to a unique filepath
safe_stub = prompt[:60].replace(' ', '_').replace('/', '_')
ts = int(time.time())
filepath = os.path.join(output_dir, f"{safe_stub or 'video'}_{ts}.mp4")
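        # fps=16 is assumed to match the frame rate the model was trained to produce.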
write_video(filepath, video_uint8, fps=16)
print(f"Saved generated video to {filepath}")
return filepath
return predict
def main():
parser = argparse.ArgumentParser()
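    # Example invocation (assuming this file is saved as app.py):
    #   python app.py --config_path configs/rolling_forcing_dmd.yaml --output_dir videos/gradio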
parser.add_argument('--config_path', type=str, default='configs/rolling_forcing_dmd.yaml',
help='Path to the model config')
    parser.add_argument('--checkpoint_path', type=str, default='checkpoints/rolling_forcing_dmd.pt',
                        help='Path to the Rolling Forcing checkpoint (.pt). If missing, the demo falls back to the base Wan weights when available.')
parser.add_argument('--output_dir', type=str, default='videos/gradio', help='Where to save generated videos')
parser.add_argument('--no_ema', action='store_true', help='Disable EMA weights when loading checkpoint')
args = parser.parse_args()
# Download checkpoint from HuggingFace if not present
# 1️⃣ Equivalent to:
# huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
wan_model_dir = snapshot_download(
repo_id="Wan-AI/Wan2.1-T2V-1.3B",
local_dir="wan_models/Wan2.1-T2V-1.3B",
local_dir_use_symlinks=False, # same as --local-dir-use-symlinks False
)
print("Wan model downloaded to:", wan_model_dir)
# 2️⃣ Equivalent to:
# huggingface-cli download TencentARC/RollingForcing checkpoints/rolling_forcing_dmd.pt --local-dir .
rolling_ckpt_path = hf_hub_download(
repo_id="TencentARC/RollingForcing",
filename="checkpoints/rolling_forcing_dmd.pt",
local_dir=".", # where to store it
local_dir_use_symlinks=False,
)
print("RollingForcing checkpoint downloaded to:", rolling_ckpt_path)
predict = build_predict(
config_path=args.config_path,
checkpoint_path=args.checkpoint_path,
output_dir=args.output_dir,
use_ema=not args.no_ema,
)
demo = gr.Interface(
fn=predict,
inputs=[
gr.Textbox(label="Text Prompt", lines=2, placeholder="A cinematic shot of a girl dancing in the sunset."),
gr.Slider(label="Number of Latent Frames", minimum=21, maximum=252, step=3, value=21),
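            # Slider bounds mirror the validation in predict(): multiples of 3 in [21, 252].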
],
outputs=gr.Video(label="Generated Video", format="mp4"),
title="Rolling Forcing: Autoregressive Long Video Diffusion in Real Time",
description=(
"Enter a prompt and generate a video using the Rolling Forcing pipeline.\n"
"**Note:** although Rolling Forcing generates videos autoregressivelty, current Gradio demo does not support streaming outputs, so the entire video will be generated before it is displayed.\n"
"\n"
"If you find this demo useful, please consider giving it a ⭐ star on [GitHub](https://github.com/TencentARC/RollingForcing)--your support is crucial for sustaining this open-source project. "
"You can also dive deeper by reading the [paper](https://arxiv.org/abs/2509.25161) or exploring the [project page](https://kunhao-liu.github.io/Rolling_Forcing_Webpage) for more details."
),
allow_flagging='never',
)
try:
# Gradio <= 3.x
demo.queue(concurrency_count=1, max_size=2)
except TypeError:
# Gradio >= 4.x
demo.queue(max_size=2)
demo.launch(show_error=True)
if __name__ == "__main__":
main()