import gradio as gr import torch from diffusers import DiffusionPipeline import numpy as np import spaces import time from PIL import Image import io import base64 # Model configuration MODEL_ID = "hpcai-tech/Open-Sora-v2" # Initialize the pipeline @spaces.GPU(duration=1500) def load_model(): """Load the Open-Sora-v2 model""" try: pipe = DiffusionPipeline.from_pretrained( MODEL_ID, torch_dtype=torch.float16, variant="fp16", use_safetensors=True ) pipe.to("cuda") # Enable memory efficient attention pipe.enable_attention_slicing() return pipe except Exception as e: print(f"Error loading model: {e}") return None # Global model variable model = None def initialize_model(): """Initialize the model on first request""" global model if model is None: model = load_model() return model is not None @spaces.GPU(duration=120) def generate_video( prompt: str, duration: int = 4, height: int = 720, width: int = 1280, num_inference_steps: int = 50, guidance_scale: float = 7.5, progress=gr.Progress() ) -> str: """ Generate a video from text prompt using Open-Sora-v2 Args: prompt: Text description of the video duration: Duration in seconds height: Video height width: Video width num_inference_steps: Number of denoising steps guidance_scale: Guidance scale for generation Returns: Path to the generated video file """ try: # Initialize model if not already done if not initialize_model(): raise Exception("Failed to initialize model") progress(0.1, desc="Initializing generation...") # Calculate number of frames based on duration (assuming 30 fps) num_frames = duration * 30 progress(0.2, desc="Starting video generation...") # Generate video frames result = model( prompt=prompt, num_frames=num_frames, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=torch.Generator().manual_seed(42) ) progress(0.8, desc="Processing frames...") # Save the generated video output_path = f"generated_video_{int(time.time())}.mp4" if hasattr(result, 'videos'): # Handle video output video_frames = result.videos[0] else: # Handle image sequence output video_frames = result.frames[0] if hasattr(result, 'frames') else result # Save as video file save_video(video_frames, output_path, fps=30) progress(1.0, desc="Video generation complete!") return output_path except Exception as e: print(f"Error generating video: {e}") raise gr.Error(f"Video generation failed: {str(e)}") def save_video(frames, output_path, fps=30): """Save video frames to MP4 file""" try: import cv2 # Convert frames to numpy if needed if torch.is_tensor(frames): frames = frames.cpu().numpy() # Ensure frames are in the correct format if len(frames.shape) == 4: frames = np.transpose(frames, (0, 2, 3, 1)) # TCHW -> THWC # Normalize frames to 0-255 frames = ((frames + 1.0) * 127.5).astype(np.uint8) # Get video dimensions height, width = frames[0].shape[:2] # Initialize video writer fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) # Write frames for frame in frames: if len(frame.shape) == 3: frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) out.write(frame) out.release() except ImportError: # Fallback: save as GIF if cv2 is not available from PIL import Image if torch.is_tensor(frames): frames = frames.cpu().numpy() if len(frames.shape) == 4: frames = np.transpose(frames, (0, 2, 3, 1)) frames = ((frames + 1.0) * 127.5).astype(np.uint8) images = [Image.fromarray(frame) for frame in frames] images[0].save( output_path.replace('.mp4', '.gif'), save_all=True, append_images=images[1:], duration=33, # ~30 fps loop=0 ) def create_interface(): """Create the Gradio interface""" with gr.Blocks( title="Text to Video - Open-Sora-v2", theme=gr.themes.Soft(), css=""" .header-text { text-align: center; font-size: 2em; margin-bottom: 0.5em; background: linear-gradient(45deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; } .subheader-text { text-align: center; color: #666; margin-bottom: 2em; } .generate-btn { background: linear-gradient(45deg, #667eea 0%, #764ba2 100%); border: none; color: white; font-weight: bold; } .generate-btn:hover { background: linear-gradient(45deg, #764ba2 0%, #667eea 100%); } """ ) as demo: gr.Markdown("""