import gradio as gr
import tempfile
import random
import json
import os
import shutil
import hashlib
import uuid
from pathlib import Path
import time
import logging

import torch
import numpy as np
from typing import Dict, Any, List, Optional, Tuple, Union

from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Constants
STORAGE_PATH = Path(os.getenv('STORAGE_PATH', './data'))
LORA_PATH = STORAGE_PATH / "loras"
OUTPUT_PATH = STORAGE_PATH / "output"
MODEL_VERSION = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
DEFAULT_PROMPT_PREFIX = ""

# Create necessary directories
STORAGE_PATH.mkdir(parents=True, exist_ok=True)
LORA_PATH.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

# Global variables tracking the loaded model state
pipe = None
current_lora_id = None
def format_time(seconds: float) -> str:
    """Format a duration in seconds as a human-readable string."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)

    parts = []
    if hours > 0:
        parts.append(f"{hours}h")
    if minutes > 0:
        parts.append(f"{minutes}m")
    if secs > 0 or not parts:
        parts.append(f"{secs}s")
    return " ".join(parts)
def upload_lora_file(file: tempfile._TemporaryFileWrapper) -> Tuple[str, str]:
    """Upload a LoRA file and return a hash-based ID for future reference.

    Args:
        file: Uploaded file object from Gradio

    Returns:
        Tuple[str, str]: Hash-based ID for the stored file (returned twice, once per output field)
    """
    if file is None:
        return "", ""
    try:
        # Calculate the SHA-256 hash of the file, reading in 4 KiB chunks
        sha256_hash = hashlib.sha256()
        with open(file.name, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                sha256_hash.update(chunk)
        file_hash = sha256_hash.hexdigest()

        # Create the destination path from the hash
        dest_path = LORA_PATH / f"{file_hash}.safetensors"

        # If the file already exists, this upload is a duplicate; reuse it
        if dest_path.exists():
            logger.info("LoRA file already exists")
            return file_hash, file_hash

        # Copy the file to the destination
        shutil.copy(file.name, dest_path)
        logger.info("A new LoRA file has been uploaded")
        return file_hash, file_hash
    except Exception as e:
        logger.error(f"Error uploading LoRA file: {e}")
        raise gr.Error(f"Failed to upload LoRA file: {str(e)}")
def get_lora_file_path(lora_id: Optional[str]) -> Optional[Path]:
    """Get the path to a LoRA file from its hash-based ID.

    Args:
        lora_id: Hash-based ID of the stored LoRA file

    Returns:
        Path: Path to the LoRA file if found, None otherwise
    """
    if not lora_id:
        return None

    lora_path = LORA_PATH / f"{lora_id}.safetensors"
    if lora_path.exists():
        return lora_path
    return None
def get_or_create_pipeline(
    enable_cpu_offload: bool = True,
    flow_shift: float = 3.0
) -> WanPipeline:
    """Get the existing pipeline or create a new one if necessary.

    Args:
        enable_cpu_offload: Whether to enable model CPU offload
        flow_shift: Flow shift parameter for the scheduler

    Returns:
        WanPipeline: The pipeline for generation
    """
    global pipe
    if pipe is None:
        logger.info("Creating new pipeline")
        # Load the VAE in float32; the rest of the pipeline runs in bfloat16
        vae = AutoencoderKLWan.from_pretrained(MODEL_VERSION, subfolder="vae", torch_dtype=torch.float32)
        pipe = WanPipeline.from_pretrained(MODEL_VERSION, vae=vae, torch_dtype=torch.bfloat16)

        # Configure the scheduler
        pipe.scheduler = UniPCMultistepScheduler.from_config(
            pipe.scheduler.config,
            flow_shift=flow_shift
        )

        if enable_cpu_offload:
            # Let model CPU offload manage device placement itself, instead of
            # moving the whole pipeline to the GPU first
            logger.info("Enabling CPU offload")
            pipe.enable_model_cpu_offload()
        else:
            pipe.to("cuda")
    else:
        # Update the existing pipeline's scheduler if the flow shift changed
        if pipe.scheduler.config.flow_shift != flow_shift:
            logger.info(f"Updating scheduler flow_shift from {pipe.scheduler.config.flow_shift} to {flow_shift}")
            pipe.scheduler = UniPCMultistepScheduler.from_config(
                pipe.scheduler.config,
                flow_shift=flow_shift
            )
    return pipe
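
# Reuse sketch (illustrative; assumes a single-process Gradio server): the
# module-level `pipe` is created once and shared across requests, so repeat
# calls with the same flow_shift return the same object:
#
#     p1 = get_or_create_pipeline(enable_cpu_offload=True, flow_shift=3.0)
#     p2 = get_or_create_pipeline(enable_cpu_offload=True, flow_shift=3.0)
#     assert p1 is p2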
def manage_lora_weights(pipe: WanPipeline, lora_id: Optional[str], lora_weight: float) -> Tuple[bool, Optional[Path]]:
    """Manage LoRA weights, loading/unloading only when necessary.

    Args:
        pipe: The pipeline to manage LoRA weights for
        lora_id: Hash-based ID of the LoRA file to use
        lora_weight: Weight of the LoRA contribution

    Returns:
        Tuple[bool, Optional[Path]]: (Is using LoRA, Path to LoRA file)
    """
    global current_lora_id

    # Determine whether we should use a LoRA at all
    using_lora = lora_id is not None and lora_id.strip() != "" and lora_weight > 0

    # If not using a LoRA but one is loaded, unload it
    if not using_lora and current_lora_id is not None:
        logger.info(f"Unloading current LoRA with ID {current_lora_id}")
        try:
            pipe.unload_lora_weights()
            current_lora_id = None
        except Exception as e:
            logger.error(f"Error unloading LoRA weights: {e}")
        return False, None

    if using_lora:
        lora_path = get_lora_file_path(lora_id)
        if not lora_path:
            # Log the event but continue with the base model
            logger.warning(f"LoRA file with ID {lora_id} not found. Using base model instead.")
            # If we had a LoRA loaded, unload it
            if current_lora_id is not None:
                logger.info(f"Unloading current LoRA with ID {current_lora_id}")
                try:
                    pipe.unload_lora_weights()
                except Exception as e:
                    logger.error(f"Error unloading LoRA weights: {e}")
                current_lora_id = None
            return False, None

        # If the LoRA ID changed, swap the weights
        if lora_id != current_lora_id:
            # Unload any previously loaded LoRA first
            if current_lora_id is not None:
                logger.info(f"Unloading current LoRA with ID {current_lora_id}")
                try:
                    pipe.unload_lora_weights()
                except Exception as e:
                    logger.error(f"Error unloading LoRA weights: {e}")
            # Load the new LoRA weights from the stored safetensors file
            logger.info(f"Loading LoRA with ID {lora_id}")
            try:
                pipe.load_lora_weights(str(lora_path), adapter_name="default")
                current_lora_id = lora_id
            except Exception as e:
                logger.error(f"Error loading LoRA weights: {e}")
                return False, None
        else:
            logger.info(f"Using currently loaded LoRA with ID {current_lora_id}")
        return True, lora_path

    return False, None
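
# A sketch of how `lora_weight` could eventually be applied (not called anywhere;
# assumes a diffusers version whose WanPipeline exposes the PEFT-backed
# `set_adapters` API). This is what the "implement the lora weight / scale
# later" note inside generate_video refers to:
def _apply_lora_scale_sketch(pipe: WanPipeline, lora_weight: float) -> None:
    """Illustrative only: scale the loaded "default" adapter by lora_weight."""
    pipe.set_adapters(["default"], adapter_weights=[lora_weight])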
def generate_video(
    prompt: str,
    negative_prompt: str,
    prompt_prefix: str,
    width: int,
    height: int,
    num_frames: int,
    guidance_scale: float,
    flow_shift: float,
    lora_id: Optional[str],
    lora_weight: float,
    inference_steps: int,
    fps: int = 16,
    seed: int = -1,
    enable_cpu_offload: bool = True,
    conditioning_image: Optional[str] = None,
    progress=gr.Progress()
) -> str:
    """Generate a video using the Wan model with optional LoRA weights.

    Args:
        prompt: Text prompt for generation
        negative_prompt: Negative text prompt
        prompt_prefix: Prefix to add to all prompts
        width: Output video width
        height: Output video height
        num_frames: Number of frames to generate
        guidance_scale: Classifier-free guidance scale
        flow_shift: Flow shift parameter for the scheduler
        lora_id: Hash-based ID of the LoRA file to use
        lora_weight: Weight of the LoRA contribution
        inference_steps: Number of inference steps
        fps: Frames per second for the output video
        seed: Random seed (-1 for random)
        enable_cpu_offload: Whether to enable CPU offload for VRAM optimization
        conditioning_image: Path to a conditioning image for image-to-video (not used in this app)
        progress: Gradio progress callback

    Returns:
        str: Path to the generated video file
    """
    global pipe, current_lora_id

    try:
        # Progress 0-5%: initialize and check inputs
        progress(0.00, desc="Initializing generation")

        # Add the prefix to the prompt
        progress(0.02, desc="Processing prompt")
        if prompt_prefix and not prompt.startswith(prompt_prefix):
            full_prompt = f"{prompt_prefix}{prompt}"
        else:
            full_prompt = prompt

        # Snap num_frames to a valid value (the model expects 8*k + 1 frames)
        adjusted_num_frames = ((num_frames - 1) // 8) * 8 + 1
        if adjusted_num_frames != num_frames:
            logger.info(f"Adjusted number of frames from {num_frames} to {adjusted_num_frames} to match model requirements")
            num_frames = adjusted_num_frames

        # Set up the random seed
        progress(0.03, desc="Setting up random seed")
        if seed == -1:
            seed = random.randint(0, 2**32 - 1)
            logger.info(f"Using randomly generated seed: {seed}")

        # Seed all RNGs for reproducibility
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        generator = torch.Generator(device="cuda").manual_seed(seed)

        # Progress 5-25%: get or create the pipeline
        progress(0.05, desc="Preparing model")
        pipe = get_or_create_pipeline(enable_cpu_offload, flow_shift)

        # Progress 25-40%: manage LoRA weights
        progress(0.25, desc="Managing LoRA weights")
        using_lora, lora_path = manage_lora_weights(pipe, lora_id, lora_weight)

        # Create a temporary file for the output
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
            output_path = temp_file.name

        # Progress 40-90%: generate the video
        progress(0.40, desc="Starting video generation")

        # Time the generation with CUDA events
        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)
        start_time.record()

        # Update progress once before generation starts
        progress(0.45, desc="Running diffusion process")

        # Generate the video (no per-step callback)
        output = pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            num_inference_steps=inference_steps,
            generator=generator,
            # The LoRA weight/scale is deliberately not passed here; applying
            # it (e.g. via set_adapters) will be implemented later.
        ).frames[0]

        # Update progress after generation completes
        progress(0.90, desc="Generation complete")
        end_time.record()
        torch.cuda.synchronize()
        generation_time = start_time.elapsed_time(end_time) / 1000  # Convert ms to seconds
        logger.info(f"Video generation completed in {format_time(generation_time)}")

        # Progress 90-95%: export the video
        progress(0.90, desc="Exporting video")
        export_to_video(output, output_path, fps=fps)

        # Progress 95-100%: save the output and clean up
        progress(0.95, desc="Saving video")

        # Save a copy to the output directory under a UUID for future reference
        output_id = str(uuid.uuid4())
        saved_output_path = OUTPUT_PATH / f"{output_id}.mp4"
        shutil.copy(output_path, saved_output_path)
        logger.info(f"Saved video with ID: {output_id}")

        # The pipeline is kept alive for reuse; only local resources are released
        progress(0.98, desc="Cleaning up resources")
        progress(1.0, desc="Generation complete")
        return output_path

    except Exception as e:
        import traceback
        error_msg = f"Error generating video: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)

        # Clean up CUDA memory on error
        if pipe is not None:
            # Try to unload any LoRA weights
            if current_lora_id is not None:
                try:
                    pipe.unload_lora_weights()
                    current_lora_id = None
                except Exception:
                    pass
            # Release the pipeline on critical errors
            try:
                pipe = None
                torch.cuda.empty_cache()
            except Exception:
                pass

        # Re-raise as a Gradio error for UI display
        raise gr.Error(f"Error generating video: {str(e)}")
# Create the Gradio app
with gr.Blocks(title="Video Generation API") as app:
    with gr.Tabs():
        # LoRA Upload Tab
        with gr.TabItem("1️⃣ Upload LoRA"):
            gr.Markdown("## Upload LoRA Weights")
            gr.Markdown("Upload your custom LoRA weights file to use for generation. The file will be automatically stored and you'll receive a unique hash-based ID.")
            with gr.Row():
                lora_file = gr.File(label="LoRA File (safetensors format)")
            with gr.Row():
                lora_id_output = gr.Textbox(label="LoRA Hash ID (use this in the generation tab)", interactive=False)
            # This will be connected after all components are defined

        # Video Generation Tab
        with gr.TabItem("2️⃣ Generate Video"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input parameters
                    prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter your prompt here...",
                        lines=3
                    )
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Enter negative prompt here...",
                        lines=3,
                        value="worst quality, low quality, blurry, jittery, distorted, ugly, deformed, disfigured, messy background"
                    )
                    prompt_prefix = gr.Textbox(
                        label="Prompt Prefix",
                        placeholder="Prefix to add to all prompts",
                        value=DEFAULT_PROMPT_PREFIX
                    )
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=1280,
                            step=8,
                            value=1280
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=720,
                            step=8,
                            value=720
                        )
                    with gr.Row():
                        num_frames = gr.Slider(
                            label="Number of Frames",
                            minimum=9,
                            maximum=257,
                            step=8,
                            value=49
                        )
                        fps = gr.Slider(
                            label="FPS",
                            minimum=1,
                            maximum=60,
                            step=1,
                            value=16
                        )
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance Scale",
                            minimum=1.0,
                            maximum=10.0,
                            step=0.1,
                            value=5.0
                        )
                        flow_shift = gr.Slider(
                            label="Flow Shift",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=3.0
                        )
                    lora_id = gr.Textbox(
                        label="LoRA ID (from upload tab)",
                        placeholder="Enter your LoRA ID here...",
                    )
                    with gr.Row():
                        lora_weight = gr.Slider(
                            label="LoRA Weight",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.7
                        )
                        inference_steps = gr.Slider(
                            label="Inference Steps",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=30
                        )
                    seed = gr.Slider(
                        label="Generation Seed (-1 for random)",
                        minimum=-1,
                        maximum=2147483647,  # 2^31 - 1
                        step=1,
                        value=-1
                    )
                    enable_cpu_offload = gr.Checkbox(
                        label="Enable Model CPU Offload (for low-VRAM GPUs)",
                        value=False
                    )
                    generate_btn = gr.Button(
                        "Generate Video",
                        variant="primary"
                    )
                with gr.Column(scale=1):
                    # Output component: just the video preview
                    preview_video = gr.Video(
                        label="Generated Video",
                        interactive=False
                    )
    # Connect the generate button
    generate_btn.click(
        fn=generate_video,
        inputs=[
            prompt,
            negative_prompt,
            prompt_prefix,
            width,
            height,
            num_frames,
            guidance_scale,
            flow_shift,
            lora_id,
            lora_weight,
            inference_steps,
            fps,
            seed,
            enable_cpu_offload
        ],
        outputs=[
            preview_video
        ]
    )

    # Connect the LoRA upload to both display fields
    lora_file.change(
        fn=upload_lora_file,
        inputs=[lora_file],
        outputs=[lora_id_output, lora_id]
    )
# Launch the app
if __name__ == "__main__":
    app.launch()
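    # For a headless or containerized deployment one might instead use, e.g.
    # (an assumption; adjust host/port to your environment):
    #
    #     app.launch(server_name="0.0.0.0", server_port=7860, show_error=True)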