"""
matrix_game_interface.py
========================

This script exposes a simple web interface for the Matrix‑Game 2.0 model via
Gradio. Given an initial image, the model produces a short video that
continues the scene forward in time. The code uses the diffusers library to
download and load the model from Hugging Face, and it automatically selects
CPU or GPU based on availability.

To run this script you must have installed the dependencies in
`requirements.txt` and logged in to the Hugging Face Hub using your access
token. You can set the token at runtime via the `HF_TOKEN` environment
variable or by passing it into the constructor of the `MatrixGame` class.

Note: generating videos with Matrix‑Game 2.0 is computationally intensive and
requires a machine with significant memory. On a CPU the generation may be
very slow. For best results use a GPU with at least 24 GiB of VRAM.
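
Example
-------
A typical way to launch the interface (the token value below is just a
placeholder)::

    export HF_TOKEN=hf_your_token_here
    python matrix_game_interface.py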
"""
from __future__ import annotations
import os
import tempfile
from typing import List, Optional
import numpy as np
from PIL import Image
import torch
from huggingface_hub import login
try:
    # Import the generic video pipeline loader. Depending on your version of
    # diffusers this symbol may live in different modules. We guard the import
    # so that the script does not crash at import time on older versions.
    from diffusers import AutoPipelineForVideo
except Exception:
    AutoPipelineForVideo = None  # type: ignore

try:
    from diffusers import ImageToVideoPipeline
except Exception:
    ImageToVideoPipeline = None  # type: ignore

try:
    import gradio as gr
except Exception:
    gr = None  # type: ignore

try:
    from moviepy.editor import ImageSequenceClip
except Exception:
    ImageSequenceClip = None  # type: ignore

class MatrixGame:
    """Wrapper around the Matrix‑Game 2.0 model.

    This class handles logging in to Hugging Face, downloading the model,
    selecting the appropriate device and performing video generation. It
    currently supports the universal mode, which uses the base distilled model
    weights. Real‑time interactive control with mouse and keyboard inputs is
    possible but not exposed through the Gradio UI.
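
    A minimal usage sketch (assumes a valid token, enough memory and a local
    image file; the names below are placeholders)::

        game = MatrixGame(hf_token="hf_...")
        frames = game.generate_frames(Image.open("first_frame.png"), num_frames=16)
        video_path = game.frames_to_video(frames, fps=15)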
"""
MODEL_ID: str = "Skywork/Matrix-Game-2.0"
    def __init__(self, hf_token: Optional[str] = None, *, mode: str = "universal"):
        """Authenticate with Hugging Face, pick a device and load the pipeline."""
        self.mode = mode
        self.hf_token = hf_token or os.environ.get("HF_TOKEN")
        if not self.hf_token:
            raise ValueError(
                "A Hugging Face token must be provided either via the HF_TOKEN "
                "environment variable or the hf_token argument."
            )
        # Authenticate with Hugging Face. This call is idempotent; if you're
        # already logged in it does nothing.
        login(token=self.hf_token, add_to_git_credential=False)
        # Select the compute device: use a GPU if available, otherwise fall
        # back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Use a lower‑precision dtype on GPU to save memory.
        if self.device.type == "cuda":
            self.dtype = torch.float16
        else:
            self.dtype = torch.float32
        # Load the pipeline. We try `AutoPipelineForVideo` first, since it
        # selects the proper class from the model's configuration. If that is
        # unavailable we fall back to `ImageToVideoPipeline`, if your diffusers
        # build provides it.
        pipeline = None
        if AutoPipelineForVideo is not None:
            try:
                pipeline = AutoPipelineForVideo.from_pretrained(
                    self.MODEL_ID,
                    torch_dtype=self.dtype,
                    token=self.hf_token,
                )
            except Exception as e:
                print(f"AutoPipelineForVideo failed to load: {e}")
        if pipeline is None and ImageToVideoPipeline is not None:
            try:
                pipeline = ImageToVideoPipeline.from_pretrained(
                    self.MODEL_ID,
                    torch_dtype=self.dtype,
                    token=self.hf_token,
                )
            except Exception as e:
                print(f"ImageToVideoPipeline failed to load: {e}")
        if pipeline is None:
            raise RuntimeError(
                "Could not load a video pipeline for Matrix‑Game 2.0. Please "
                "ensure diffusers is up to date (>=0.33) and that you have GPU "
                "support installed."
            )
        self.pipeline = pipeline.to(self.device)
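        # Note: if you run into GPU out-of-memory errors, most diffusers
        # pipelines also support `self.pipeline.enable_model_cpu_offload()`
        # (requires the `accelerate` package). Whether it helps here depends
        # on the pipeline class that was actually loaded, so it is left
        # disabled by default.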
    def generate_frames(self, image: Image.Image, num_frames: int = 8) -> List[Image.Image]:
        """Generate a sequence of frames given an initial image.

        Args:
            image: A PIL.Image that will act as the first frame of the video.
            num_frames: The number of frames to generate (including the input).

        Returns:
            A list of PIL.Image objects representing the generated video frames.
        """
        # Normalize the input image. The diffusers pipelines handle resizing
        # internally, but explicitly converting to RGB ensures consistent
        # results.
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL.Image")
        image = image.convert("RGB")
        # Some pipelines accept `num_frames` directly to control the video
        # length; others ignore the argument and use a default value. The
        # Matrix‑Game model natively produces 16 frames per call, so when the
        # caller requests fewer frames the result is truncated below.
        # Autocast only on GPU: CPU autocast does not accept float32 as a
        # target dtype.
        with torch.autocast(self.device.type, dtype=self.dtype, enabled=self.device.type == "cuda"):
            result = self.pipeline(image, num_frames=num_frames)
        # The result is usually a simple namespace with a `frames` attribute;
        # some diffusers versions return a dictionary with a "frames" key
        # instead.
        frames = getattr(result, "frames", None)
        if frames is None and isinstance(result, dict):
            frames = result.get("frames")
        if frames is None:
            raise RuntimeError("Unexpected output format from the pipeline")
        # Video pipelines commonly return one list of frames per batch item;
        # unwrap the first batch element in that case.
        if len(frames) > 0 and isinstance(frames[0], (list, tuple)):
            frames = list(frames[0])
        # Limit to the requested number of frames if more were produced.
        return frames[:num_frames]
    def frames_to_video(self, frames: List[Image.Image], fps: int = 15) -> str:
        """Convert a list of frames into a temporary MP4 file.

        Args:
            frames: A list of PIL images.
            fps: Frames per second for the output video.

        Returns:
            The file path to the generated MP4 video.
        """
        if ImageSequenceClip is None:
            raise ImportError(
                "moviepy is required to assemble videos. Please install it with "
                "`pip install moviepy` or use an alternative method."
            )
        # Convert the PIL images to uint8 numpy arrays, as expected by moviepy.
        frame_arrays = [np.array(frame) for frame in frames]
        clip = ImageSequenceClip(frame_arrays, fps=fps)
        # Write the clip to a temporary file.
        tmp_dir = tempfile.mkdtemp(prefix="matrix_game_")
        video_path = os.path.join(tmp_dir, "output.mp4")
        clip.write_videofile(video_path, codec="libx264", audio=False, logger=None)
        return video_path

def launch_interface():
    """Launch a Gradio interface for Matrix‑Game 2.0."""
    if gr is None:
        raise ImportError(
            "Gradio is not installed. Please install it with `pip install gradio`."
        )
    # Instantiate the model wrapper once; this downloads the weights on first
    # use. The token is read from the environment. You could hard‑code the
    # token here instead, but be mindful of security best practices.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError(
            "Please set the HF_TOKEN environment variable to your Hugging Face "
            "access token before launching the interface."
        )
    matrix_game = MatrixGame(hf_token=hf_token)

    def generate_fn(image: Image.Image, num_frames: int) -> str:
        """Callback invoked by Gradio to generate a video file."""
        # Gradio sliders may deliver the value as a float; coerce it to int.
        frames = matrix_game.generate_frames(image, num_frames=int(num_frames))
        video_path = matrix_game.frames_to_video(frames, fps=15)
        return video_path

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Matrix‑Game 2.0 Demo

            Upload an image and choose how many frames to generate. The model
            will synthesize a short video that extends the scene forward in
            time. Note that generation may take several minutes on machines
            without high‑end GPUs.
            """
        )
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Initial Frame")
                num_frames = gr.Slider(
                    minimum=4,
                    maximum=32,
                    step=1,
                    value=16,
                    label="Number of Frames",
                    info="Total frames in the generated video (including the initial frame)",
                )
                generate_btn = gr.Button("Generate Video")
            with gr.Column():
                video_output = gr.Video(label="Generated Video", interactive=False)
        generate_btn.click(
            fn=generate_fn,
            inputs=[image_input, num_frames],
            outputs=video_output,
        )
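    # Optional: depending on your Gradio version, enabling the request queue
    # with `demo.queue()` before `launch()` can help when a single generation
    # takes several minutes.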
    demo.launch()

if __name__ == "__main__":
    launch_interface()