import spaces
import subprocess
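# Install FlashAttention at startup, skipping the local CUDA kernel build
# (FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE); a common pattern on Hugging Face
# Spaces, where compiling the kernels during startup is impractical.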
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
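
# -----------------------------
# Gradio demo for Rolling Forcing: real-time autoregressive long-video diffusion
# on top of the Wan2.1-T2V-1.3B backbone. Running this script downloads the
# required weights from the Hugging Face Hub and launches the web UI
# (e.g. `python app.py`; the filename app.py is only an assumption here).
# -----------------------------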

import os
import argparse
import time
from typing import Optional

import torch
from torchvision.io import write_video
from omegaconf import OmegaConf
from einops import rearrange
import gradio as gr

from pipeline import CausalInferencePipeline
from huggingface_hub import snapshot_download, hf_hub_download


# -----------------------------
# Globals (loaded once per process)
# -----------------------------

_PIPELINE: Optional[torch.nn.Module] = None
_DEVICE: Optional[torch.device] = None


def _ensure_gpu():
    if not torch.cuda.is_available():
        raise gr.Error("CUDA GPU is required to run this demo. Please run on a machine with an NVIDIA GPU.")
    # Bind to GPU:0 by default
    torch.cuda.set_device(0)


def _load_pipeline(config_path: str, checkpoint_path: Optional[str], use_ema: bool) -> torch.nn.Module:
    global _PIPELINE, _DEVICE
    if _PIPELINE is not None:
        return _PIPELINE

    _ensure_gpu()
    _DEVICE = torch.device("cuda:0")

    # Load and merge configs
    config = OmegaConf.load(config_path)
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)

    # Construct the causal inference pipeline
    pipeline = CausalInferencePipeline(config, device=_DEVICE)

    # Load checkpoint if provided
    if checkpoint_path and os.path.exists(checkpoint_path):
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        if use_ema and 'generator_ema' in state_dict:
            state_dict_to_load = state_dict['generator_ema']
            # Remove possible FSDP prefix
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict_to_load.items():
                new_state_dict[k.replace("_fsdp_wrapped_module.", "")] = v
            state_dict_to_load = new_state_dict
        else:
            state_dict_to_load = state_dict.get('generator', state_dict)
        pipeline.generator.load_state_dict(state_dict_to_load, strict=False)

    # The codebase assumes bfloat16 on GPU
    pipeline = pipeline.to(device=_DEVICE, dtype=torch.bfloat16)
    pipeline.eval()

    # Quick sanity path check for Wan models to give friendly errors
    wan_dir = os.path.join('wan_models', 'Wan2.1-T2V-1.3B')
    if not os.path.isdir(wan_dir):
        raise gr.Error(
            "Wan2.1-T2V-1.3B not found at 'wan_models/Wan2.1-T2V-1.3B'.\n"
            "Please download it first, e.g.:\n"
            "huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir wan_models/Wan2.1-T2V-1.3B"
        )

    _PIPELINE = pipeline
    return _PIPELINE


def build_predict(config_path: str, checkpoint_path: Optional[str], output_dir: str, use_ema: bool):
    os.makedirs(output_dir, exist_ok=True)

    # On Hugging Face ZeroGPU Spaces, `spaces.GPU` allocates a GPU for the duration of each call.
    @spaces.GPU
    def predict(prompt: str, num_frames: int) -> str:
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter a non-empty text prompt.")

        num_frames = int(num_frames)
        if num_frames % 3 != 0 or not (21 <= num_frames <= 252):
            raise gr.Error("Number of frames must be a multiple of 3 between 21 and 252.")

        pipeline = _load_pipeline(config_path, checkpoint_path, use_ema)

        # Prepare inputs
        prompts = [prompt.strip()]
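        # Latent noise: [batch, latent frames, channels, latent height, latent width];
        # 16x60x104 corresponds to Wan2.1-T2V-1.3B's 480x832 output with the VAE's 8x spatial downsampling.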
        noise = torch.randn([1, num_frames, 16, 60, 104], device=_DEVICE, dtype=torch.bfloat16)

        # inference_mode() already disables autograd; autocast keeps compute in bfloat16
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            video = pipeline.inference_rolling_forcing(
                noise=noise,
                text_prompts=prompts,
                return_latents=False,
                initial_latent=None,
            )

        # video: [B=1, T, C, H, W] in [0,1]
        video = rearrange(video, 'b t c h w -> b t h w c')[0]
        video_uint8 = (video * 255.0).clamp(0, 255).to(torch.uint8).cpu()

        # Save to a unique filepath
        safe_stub = prompt[:60].replace(' ', '_').replace('/', '_')
        ts = int(time.time())
        filepath = os.path.join(output_dir, f"{safe_stub or 'video'}_{ts}.mp4")
        write_video(filepath, video_uint8, fps=16)
        print(f"Saved generated video to {filepath}")

        return filepath

    return predict


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/rolling_forcing_dmd.yaml',
                        help='Path to the model config')
    parser.add_argument('--checkpoint_path', type=str, default='checkpoints/rolling_forcing_dmd.pt',
                        help='Path to the Rolling Forcing checkpoint (.pt). If the file is missing, the demo falls back to the base Wan weights.')
    parser.add_argument('--output_dir', type=str, default='videos/gradio', help='Where to save generated videos')
    parser.add_argument('--no_ema', action='store_true', help='Disable EMA weights when loading checkpoint')
    args = parser.parse_args()

    # Download checkpoint from HuggingFace if not present
    # 1️⃣ Equivalent to:
    # huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
    wan_model_dir = snapshot_download(
        repo_id="Wan-AI/Wan2.1-T2V-1.3B",
        local_dir="wan_models/Wan2.1-T2V-1.3B",
        local_dir_use_symlinks=False,  # same as --local-dir-use-symlinks False
    )
    print("Wan model downloaded to:", wan_model_dir)

    # 2️⃣ Equivalent to:
    # huggingface-cli download TencentARC/RollingForcing checkpoints/rolling_forcing_dmd.pt --local-dir .
    rolling_ckpt_path = hf_hub_download(
        repo_id="TencentARC/RollingForcing",
        filename="checkpoints/rolling_forcing_dmd.pt",
        local_dir=".",  # where to store it
        local_dir_use_symlinks=False,
    )
    print("RollingForcing checkpoint downloaded to:", rolling_ckpt_path)
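    # With local_dir=".", the checkpoint lands at ./checkpoints/rolling_forcing_dmd.pt,
    # which matches the default --checkpoint_path used below.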

    predict = build_predict(
        config_path=args.config_path,
        checkpoint_path=args.checkpoint_path,
        output_dir=args.output_dir,
        use_ema=not args.no_ema,
    )

    demo = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(label="Text Prompt", lines=2, placeholder="A cinematic shot of a girl dancing in the sunset."),
            gr.Slider(label="Number of Latent Frames", minimum=21, maximum=252, step=3, value=21),
        ],
        outputs=gr.Video(label="Generated Video", format="mp4"),
        title="Rolling Forcing: Autoregressive Long Video Diffusion in Real Time",
        description=(
            "Enter a prompt and generate a video using the Rolling Forcing pipeline.\n"
            "**Note:** although Rolling Forcing generates videos autoregressively, this Gradio demo does not support streaming output, so the entire video is generated before it is displayed.\n"
            "\n"
            "If you find this demo useful, please consider giving it a ⭐ star on [GitHub](https://github.com/TencentARC/RollingForcing); your support is crucial for sustaining this open-source project. "
            "You can also dive deeper by reading the [paper](https://arxiv.org/abs/2509.25161) or exploring the [project page](https://kunhao-liu.github.io/Rolling_Forcing_Webpage) for more details."
        ),
        allow_flagging='never',
    )

    try:
        # Gradio <= 3.x
        demo.queue(concurrency_count=1, max_size=2)
    except TypeError:
        # Gradio >= 4.x
        demo.queue(max_size=2)
    demo.launch(show_error=True)


if __name__ == "__main__":
    main()