import gradio as gr
import replicate
import os
from PIL import Image
import requests
from io import BytesIO
import time
import tempfile
import base64
import spaces
import torch
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
import numpy as np
import random
import gc
# ===========================
# Configuration
# ===========================
# Set up Replicate API key (only when the variable is present;
# assigning None to os.environ would raise a TypeError at import time)
if os.getenv('REPLICATE_API_TOKEN'):
    os.environ['REPLICATE_API_TOKEN'] = os.getenv('REPLICATE_API_TOKEN')

# Video Model Configuration
VIDEO_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
LANDSCAPE_WIDTH = 832
LANDSCAPE_HEIGHT = 480
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
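# For reference: with FIXED_FPS = 16 these work out to MIN_DURATION = 8/16 = 0.5 s
# and MAX_DURATION = 81/16 ≈ 5.1 s. The duration slider in the UI below is capped
# at 2.0 s to keep GPU memory usage down.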
default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "static, still, no motion, frozen"

# ===========================
# Initialize Video Pipeline
# ===========================
# Global pipeline handles; the pipeline is loaded lazily (see generate_video),
# with initialize_video_pipeline available for eager loading.
video_pipe = None
video_pipeline_ready = False


def initialize_video_pipeline():
    global video_pipe, video_pipeline_ready
    if video_pipe is None and not video_pipeline_ready:
        try:
            print("Starting video pipeline initialization...")
            # Install PyTorch 2.8 (if needed)
            os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
            # Import LoRA loading utilities
            from peft import LoraConfig, get_peft_model, TaskType
            video_pipe = WanImageToVideoPipeline.from_pretrained(
                VIDEO_MODEL_ID,
                transformer=WanTransformer3DModel.from_pretrained(
                    'cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
                    subfolder='transformer',
                    torch_dtype=torch.bfloat16,
                    device_map='cuda',
                ),
                transformer_2=WanTransformer3DModel.from_pretrained(
                    'cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
                    subfolder='transformer_2',
                    torch_dtype=torch.bfloat16,
                    device_map='cuda',
                ),
                torch_dtype=torch.bfloat16,
            ).to('cuda')
            # Clear memory after loading
            gc.collect()
            torch.cuda.empty_cache()
            # Load Lightning LoRA
            try:
                print("Loading Lightning LoRA adapter...")
                video_pipe.transformer.load_adapter("Lightx2v/lightx2v_I2V_14B_480p_cfg_step_4", adapter_name="lightx2v")
                video_pipe.transformer_2.load_adapter("Lightx2v/lightx2v_I2V_14B_480p_cfg_step_4", adapter_name="lightx2v_2")
                video_pipe.transformer.set_adapters(["lightx2v"], adapter_weights=[1.0])
                video_pipe.transformer_2.set_adapters(["lightx2v_2"], adapter_weights=[1.0])
                print("Lightning LoRA loaded successfully")
            except Exception as e:
                print(f"Warning: Could not load Lightning LoRA: {e}")
                # Continue without LoRA
            # Clear memory again
            gc.collect()
            torch.cuda.empty_cache()
            # Try to optimize if module available
            try:
                from optimization import optimize_pipeline_
                print("Optimizing pipeline...")
                optimize_pipeline_(
                    video_pipe,
                    image=Image.new('RGB', (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT)),
                    prompt='prompt',
                    height=LANDSCAPE_HEIGHT,
                    width=LANDSCAPE_WIDTH,
                    num_frames=MAX_FRAMES_MODEL,
                )
                print("Pipeline optimization complete")
            except ImportError:
                print("Optimization module not found, running without optimization")
            except Exception as e:
                print(f"Warning: Optimization failed: {e}")
            video_pipeline_ready = True
            print("Video pipeline initialized successfully!")
        except Exception as e:
            print(f"Error initializing video pipeline: {e}")
            video_pipe = None
            video_pipeline_ready = False
# ===========================
# Image Processing Functions
# ===========================
def upload_image_to_hosting(image):
    """Upload image to multiple hosting services with fallback"""
    # Method 1: Try imgbb.com
    try:
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        buffered.seek(0)
        img_base64 = base64.b64encode(buffered.getvalue()).decode()
        response = requests.post(
            "https://api.imgbb.com/1/upload",
            data={
                'key': '6d207e02198a847aa98d0a2a901485a5',
                'image': img_base64,
            }
        )
        if response.status_code == 200:
            data = response.json()
            if data.get('success'):
                return data['data']['url']
    except:
        pass
    # Method 2: Try 0x0.st
    try:
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        buffered.seek(0)
        files = {'file': ('image.png', buffered, 'image/png')}
        response = requests.post("https://0x0.st", files=files)
        if response.status_code == 200:
            return response.text.strip()
    except:
        pass
    # Method 3: Fallback to base64
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    buffered.seek(0)
    img_base64 = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{img_base64}"
def process_images(prompt, image1, image2=None):
    """Process uploaded images with Replicate API"""
    if not image1:
        return None, "Please upload at least one image", None
    if not os.getenv('REPLICATE_API_TOKEN'):
        return None, "Please set REPLICATE_API_TOKEN", None
    try:
        image_urls = []
        # Upload images
        url1 = upload_image_to_hosting(image1)
        image_urls.append(url1)
        if image2:
            url2 = upload_image_to_hosting(image2)
            image_urls.append(url2)
        # Run the model
        output = replicate.run(
            "google/nano-banana",
            input={
                "prompt": prompt,
                "image_input": image_urls
            }
        )
        if output is None:
            return None, "No output received", None
        # Get the generated image
        img = None
        try:
            if hasattr(output, 'read'):
                img_data = output.read()
                img = Image.open(BytesIO(img_data))
        except:
            pass
        if img is None:
            try:
                if hasattr(output, 'url'):
                    output_url = output.url()
                    response = requests.get(output_url, timeout=30)
                    if response.status_code == 200:
                        img = Image.open(BytesIO(response.content))
            except:
                pass
        if img is None:
            output_url = None
            if isinstance(output, str):
                output_url = output
            elif isinstance(output, list) and len(output) > 0:
                output_url = output[0]
            if output_url:
                response = requests.get(output_url, timeout=30)
                if response.status_code == 200:
                    img = Image.open(BytesIO(response.content))
        if img:
            return img, "✨ Image generated successfully! You can now generate a video from this image.", img
        else:
            return None, "Could not process output", None
    except Exception as e:
        return None, f"Error: {str(e)[:100]}", None
# ===========================
# Video Generation Functions
# ===========================
def resize_image_for_video(image: Image.Image) -> Image.Image:
    """Resize image for video generation"""
    if image.height > image.width:
        transposed = image.transpose(Image.Transpose.ROTATE_90)
        resized = resize_image_landscape(transposed)
        return resized.transpose(Image.Transpose.ROTATE_270)
    return resize_image_landscape(image)


def resize_image_landscape(image: Image.Image) -> Image.Image:
    """Resize landscape image to target dimensions"""
    target_aspect = LANDSCAPE_WIDTH / LANDSCAPE_HEIGHT
    width, height = image.size
    in_aspect = width / height
    if in_aspect > target_aspect:
        new_width = round(height * target_aspect)
        left = (width - new_width) // 2
        image = image.crop((left, 0, left + new_width, height))
    else:
        new_height = round(width / target_aspect)
        top = (height - new_height) // 2
        image = image.crop((0, top, width, top + new_height))
    return image.resize((LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT), Image.LANCZOS)
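# Worked example of the crop-then-resize above (illustrative numbers, not from
# the original code): a 1024x768 input has aspect 1.33, below the 832/480 ≈ 1.73
# target, so the height is center-cropped to round(1024 / 1.73) ≈ 591 px before
# the final resize to 832x480. Portrait inputs are rotated to landscape,
# processed the same way, then rotated back by resize_image_for_video.
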
def get_duration(input_image, prompt, steps, negative_prompt, duration_seconds, guidance_scale, guidance_scale_2, seed, randomize_seed):
    # Shorter GPU duration for stability: 10 s per inference step, capped at 60 s
    return min(60, int(steps) * 10)


# ZeroGPU: size the GPU slot dynamically from the step count via get_duration above
@spaces.GPU(duration=get_duration)
def generate_video(
    input_image,
    prompt,
    steps=4,
    negative_prompt=default_negative_prompt,
    duration_seconds=2.0,  # Reduced default
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
    progress=gr.Progress(track_tqdm=True),
):
| """Generate a video from an input image""" | |
| if input_image is None: | |
| raise gr.Error("Please generate or upload an image first.") | |
| try: | |
| # Initialize pipeline if needed (simplified) | |
| global video_pipe | |
| if video_pipe is None: | |
| print("Initializing video pipeline...") | |
| video_pipe = WanImageToVideoPipeline.from_pretrained( | |
| VIDEO_MODEL_ID, | |
| torch_dtype=torch.bfloat16, | |
| variant="fp16", | |
| use_safetensors=True | |
| ).to('cuda') | |
| # Load Lightning LoRA for faster generation | |
| try: | |
| video_pipe.load_lora_weights("Kijai/WanVideo_comfy", weight_name="Wan22-Lightning-4-cfg1_bf16_v0.9.safetensors") | |
| video_pipe.fuse_lora(lora_scale=1.0) | |
| except: | |
| pass | |
| # Clear cache before generation | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| # Ensure frames are divisible by 4 and limit to reasonable range | |
| num_frames = int(round(duration_seconds * FIXED_FPS)) | |
| num_frames = np.clip(num_frames, 9, 33) # Limit to 0.5-2 seconds | |
| num_frames = ((num_frames - 1) // 4) * 4 + 1 | |
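        # The 4k + 1 frame constraint reflects the temporal compression of Wan's
        # video VAE (groups of 4 frames plus one leading frame). Example: a 1.5 s
        # request at 16 fps gives round(1.5 * 16) = 24 frames, which rounds down to
        # ((24 - 1) // 4) * 4 + 1 = 21 frames, about 1.3 s of output video.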
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
        # Resize image
        resized_image = resize_image_for_video(input_image)
        # Generate with reduced settings
        with torch.inference_mode():
            with torch.autocast('cuda', dtype=torch.bfloat16):
                output_frames_list = video_pipe(
                    image=resized_image,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    height=resized_image.height,
                    width=resized_image.width,
                    num_frames=num_frames,
                    guidance_scale=float(guidance_scale),
                    guidance_scale_2=float(guidance_scale_2),
                    num_inference_steps=int(steps),
                    generator=torch.Generator(device="cuda").manual_seed(current_seed),
                ).frames[0]
        # Clear cache after generation
        torch.cuda.empty_cache()
        gc.collect()
        # Save video
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            video_path = tmpfile.name
        export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
        return video_path, current_seed, f"🎬 Video generated successfully! ({num_frames} frames)"
    except RuntimeError as e:
        torch.cuda.empty_cache()
        gc.collect()
        if "out of memory" in str(e).lower():
            raise gr.Error("GPU memory exceeded. Try reducing duration to 1-2 seconds and steps to 4.")
        else:
            raise gr.Error(f"GPU error: {str(e)[:100]}")
    except Exception as e:
        raise gr.Error(f"Error: {str(e)[:200]}")
# ===========================
# Enhanced CSS
# ===========================
css = """
.gradio-container {
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
    min-height: 100vh;
}
.header-container {
    background: linear-gradient(135deg, #ffd93d 0%, #ffb347 100%);
    padding: 2.5rem;
    border-radius: 24px;
    margin-bottom: 2.5rem;
    box-shadow: 0 20px 60px rgba(255, 179, 71, 0.25);
}
.logo-text {
    font-size: 3.5rem;
    font-weight: 900;
    color: #2d3436;
    text-align: center;
    margin: 0;
    letter-spacing: -2px;
}
.subtitle {
    color: #2d3436;
    text-align: center;
    font-size: 1.2rem;
    margin-top: 0.5rem;
    opacity: 0.9;
    font-weight: 600;
}
.main-content {
    background: rgba(255, 255, 255, 0.95);
    backdrop-filter: blur(20px);
    border-radius: 24px;
    padding: 2.5rem;
    box-shadow: 0 10px 40px rgba(0, 0, 0, 0.08);
    margin-bottom: 2rem;
}
.gr-button-primary {
    background: linear-gradient(135deg, #ffd93d 0%, #ffb347 100%) !important;
    border: none !important;
    color: #2d3436 !important;
    font-weight: 700 !important;
    font-size: 1.1rem !important;
    padding: 1.2rem 2rem !important;
    border-radius: 14px !important;
    transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1) !important;
    text-transform: uppercase;
    letter-spacing: 1px;
    width: 100%;
    margin-top: 1rem !important;
}
.gr-button-primary:hover {
    transform: translateY(-3px) !important;
    box-shadow: 0 15px 40px rgba(255, 179, 71, 0.35) !important;
}
.gr-button-secondary {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    border: none !important;
    color: white !important;
    font-weight: 700 !important;
    font-size: 1.1rem !important;
    padding: 1.2rem 2rem !important;
    border-radius: 14px !important;
    transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1) !important;
    text-transform: uppercase;
    letter-spacing: 1px;
    width: 100%;
    margin-top: 1rem !important;
}
.gr-button-secondary:hover {
    transform: translateY(-3px) !important;
    box-shadow: 0 15px 40px rgba(102, 126, 234, 0.35) !important;
}
.section-title {
    font-size: 1.8rem;
    font-weight: 800;
    color: #2d3436;
    margin-bottom: 1rem;
    padding-bottom: 0.5rem;
    border-bottom: 3px solid #ffd93d;
}
.status-text {
    font-family: 'SF Mono', 'Monaco', monospace;
    color: #00b894;
    font-size: 0.9rem;
}
.image-container {
    border-radius: 14px !important;
    overflow: hidden;
    border: 2px solid #e1e8ed !important;
    background: #fafbfc !important;
}
footer {
    display: none !important;
}
"""
# ===========================
# Gradio Interface
# ===========================
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    # Shared state for passing image between tabs
    generated_image_state = gr.State(None)

    with gr.Column(elem_classes="header-container"):
        gr.HTML("""
            <h1 class="logo-text">🍌 Nano Banana + Video</h1>
            <p class="subtitle">AI-Powered Image Style Transfer with Video Generation</p>
            <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-top: 20px;">
                <a href="https://huggingface.co/spaces/openfree/Nano-Banana-Upscale" target="_blank">
                    <img src="https://img.shields.io/static/v1?label=NANO%20BANANA&message=UPSCALE&color=%230000ff&labelColor=%23800080&logo=GOOGLE&logoColor=white&style=for-the-badge" alt="Nano Banana Upscale">
                </a>
                <a href="https://discord.gg/openfreeai" target="_blank">
                    <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord Openfree AI">
                </a>
            </div>
        """)

    with gr.Tabs():
        # Tab 1: Image Generation
        with gr.TabItem("🎨 Step 1: Generate Image"):
            with gr.Column(elem_classes="main-content"):
                gr.HTML('<h2 class="section-title">🎨 Image Style Transfer</h2>')
                with gr.Row(equal_height=True):
                    with gr.Column(scale=1):
                        style_prompt = gr.Textbox(
                            label="Style Description",
                            placeholder="Describe your style...",
                            lines=3,
                            value="Make the sheets in the style of the logo. Make the scene natural.",
                        )
                        with gr.Row(equal_height=True):
                            image1 = gr.Image(
                                label="Primary Image",
                                type="pil",
                                height=200,
                                elem_classes="image-container"
                            )
                            image2 = gr.Image(
                                label="Secondary Image (Optional)",
                                type="pil",
                                height=200,
                                elem_classes="image-container"
                            )
                        generate_img_btn = gr.Button(
                            "Generate Image ✨",
                            variant="primary",
                            size="lg"
                        )
                    with gr.Column(scale=1):
                        output_image = gr.Image(
                            label="Generated Result",
                            type="pil",
                            height=420,
                            elem_classes="image-container"
                        )
                        img_status = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=1,
                            elem_classes="status-text",
                            value="Ready to generate image..."
                        )
                        send_to_video_btn = gr.Button(
                            "Send to Video Generation →",
                            variant="secondary",
                            size="lg",
                            visible=False
                        )

        # Tab 2: Video Generation
        with gr.TabItem("🎬 Step 2: Generate Video"):
            with gr.Column(elem_classes="main-content"):
                gr.HTML('<h2 class="section-title">🎬 Video Generation from Image</h2>')
                with gr.Row():
                    with gr.Column():
                        video_input_image = gr.Image(
                            type="pil",
                            label="Input Image (from Step 1 or upload new)",
                            elem_classes="image-container"
                        )
                        video_prompt = gr.Textbox(
                            label="Animation Prompt",
                            value=default_prompt_i2v,
                            lines=3
                        )
                        duration_input = gr.Slider(
                            minimum=0.5,
                            maximum=2.0,
                            step=0.1,
                            value=1.5,
                            label="Duration (seconds)",
                            info="Shorter videos use less memory"
                        )
                        with gr.Accordion("Advanced Settings", open=False):
                            video_negative_prompt = gr.Textbox(
                                label="Negative Prompt",
                                value=default_negative_prompt,
                                lines=3
                            )
                            video_seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                                value=42
                            )
                            randomize_seed = gr.Checkbox(
                                label="Randomize seed",
                                value=True
                            )
                            steps_slider = gr.Slider(
                                minimum=1,
                                maximum=8,
                                step=1,
                                value=4,
                                label="Inference Steps (4 recommended)"
                            )
                            guidance_1 = gr.Slider(
                                minimum=0.0,
                                maximum=10.0,
                                step=0.5,
                                value=1,
                                label="Guidance Scale - High Noise"
                            )
                            guidance_2 = gr.Slider(
                                minimum=0.0,
                                maximum=10.0,
                                step=0.5,
                                value=1,
                                label="Guidance Scale - Low Noise"
                            )
                        generate_video_btn = gr.Button(
                            "Generate Video 🎬",
                            variant="primary",
                            size="lg"
                        )
                    with gr.Column():
                        video_output = gr.Video(
                            label="Generated Video",
                            autoplay=True
                        )
                        video_status = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=1,
                            elem_classes="status-text",
                            value="Ready to generate video..."
                        )

    # Event Handlers
    def on_image_generated(prompt, img1, img2):
        img, status, state_img = process_images(prompt, img1, img2)
        if img:
            return img, status, state_img, gr.update(visible=True)
        return img, status, state_img, gr.update(visible=False)

    def send_image_to_video(img):
        if img:
            return img, "Image loaded! Ready to generate video."
        return None, "No image to send."

    # Image generation events
    generate_img_btn.click(
        fn=on_image_generated,
        inputs=[style_prompt, image1, image2],
        outputs=[output_image, img_status, generated_image_state, send_to_video_btn]
    )
    # Send to video tab
    send_to_video_btn.click(
        fn=send_image_to_video,
        inputs=[generated_image_state],
        outputs=[video_input_image, video_status]
    )
    # Video generation events
    video_inputs = [
        video_input_image, video_prompt, steps_slider,
        video_negative_prompt, duration_input,
        guidance_1, guidance_2, video_seed, randomize_seed
    ]
    def generate_video_wrapper(img, prompt, steps, neg_prompt, duration, g1, g2, seed, rand_seed):
        try:
            # Arguments are forwarded in this order so get_duration sees the same
            # values (notably steps) when sizing the GPU slot
            video_path, new_seed, status = generate_video(
                img, prompt, steps, neg_prompt, duration, g1, g2, seed, rand_seed
            )
            return video_path, new_seed, status
        except Exception as e:
            return None, seed, f"Error: {str(e)}"

    generate_video_btn.click(
        fn=generate_video_wrapper,
        inputs=video_inputs,
        outputs=[video_output, video_seed, video_status]
    )
# Launch
if __name__ == "__main__":
    # Don't initialize video pipeline on startup to avoid blocking
    print("Starting application...")
    print("Note: Video pipeline will initialize on first use")
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )