Spaces:
Paused
Paused
| import gradio as gr | |
| import replicate | |
| import os | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| import time | |
| import tempfile | |
| import base64 | |
| import numpy as np | |
| import random | |
| import gc | |
# GPU-related imports are handled conditionally: PyTorch is optional, and
# TORCH_AVAILABLE records whether the import succeeded.
try:
    import torch
except ImportError:
    TORCH_AVAILABLE = False
    print("Warning: PyTorch not available. Video generation will be disabled.")
else:
    TORCH_AVAILABLE = True
# ===========================
# Configuration
# ===========================
# Set up Replicate API key.
# Re-export the token only when it is actually set: assigning None to
# os.environ raises TypeError, which would abort the module at import time
# before the "token not set" checks elsewhere in this file could report it.
_replicate_api_token = os.getenv('REPLICATE_API_TOKEN')
if _replicate_api_token:
    os.environ['REPLICATE_API_TOKEN'] = _replicate_api_token

# Video Model Configuration
VIDEO_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
# Output resolution requested from the video pipeline (width x height).
LANDSCAPE_WIDTH = 832
LANDSCAPE_HEIGHT = 480
# Upper bound used when drawing a random seed.
MAX_SEED = np.iinfo(np.int32).max
# Frame rate used to convert a duration into a frame count and to export video.
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
# Durations (seconds, one decimal) corresponding to the frame limits above.
MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS, 1)
default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "static, still, no motion, frozen"

# ===========================
# Initialize Video Pipeline (Lazy Loading)
# ===========================
# Cached pipeline instance; stays None until generate_video() loads it.
video_pipe = None
video_pipeline_ready = False
def lazy_import_video_dependencies():
    """Import the diffusers video helpers on demand.

    Returns:
        tuple: ``(WanImageToVideoPipeline, WanTransformer3DModel,
        export_to_video)`` when every import succeeds, otherwise
        ``(None, None, None)`` after printing a warning.

    Raises:
        gr.Error: If PyTorch could not be imported at module load.
    """
    if not TORCH_AVAILABLE:
        raise gr.Error("PyTorch is not installed. Video generation is not available.")

    try:
        # Deferred so the rest of the app still starts without diffusers.
        from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
        from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
        from diffusers.utils.export_utils import export_to_video
    except ImportError as import_error:
        print(f"Warning: Video dependencies not available: {import_error}")
        return None, None, None

    return WanImageToVideoPipeline, WanTransformer3DModel, export_to_video
| # =========================== | |
| # Image Processing Functions | |
| # =========================== | |
def upload_image_to_hosting(image):
    """Upload an image to a public host and return a URL for it.

    Hosts are tried in order and the first success wins:

    1. imgbb.com
    2. 0x0.st
    3. an inline ``data:image/png;base64,...`` URI (no network involved)

    Args:
        image: Object exposing a PIL-style ``save(fp, format=...)`` method.

    Returns:
        str: The URL reported by the host, or the base64 data URI fallback.
    """
    # Encode the PNG once and reuse the bytes for every attempt.
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    png_bytes = buffered.getvalue()
    img_base64 = base64.b64encode(png_bytes).decode()

    # Method 1: Try imgbb.com
    try:
        response = requests.post(
            "https://api.imgbb.com/1/upload",
            data={
                # SECURITY: an API key is committed in source. Set
                # IMGBB_API_KEY to override it; the embedded value is kept
                # only as a backward-compatible default and should be
                # rotated/removed.
                'key': os.getenv('IMGBB_API_KEY', '6d207e02198a847aa98d0a2a901485a5'),
                'image': img_base64,
            },
            timeout=10,
        )
        if response.status_code == 200:
            data = response.json()
            if data.get('success'):
                return data['data']['url']
    except Exception as e:
        # Deliberate best effort: any failure just moves on to the next host.
        print(f"imgbb upload failed: {e}")

    # Method 2: Try 0x0.st
    try:
        files = {'file': ('image.png', BytesIO(png_bytes), 'image/png')}
        response = requests.post("https://0x0.st", files=files, timeout=10)
        if response.status_code == 200:
            hosted_url = response.text.strip()
            # Only trust the body when it actually looks like a URL; an empty
            # or unexpected body falls through to the data-URI fallback.
            if hosted_url.startswith(("http://", "https://")):
                return hosted_url
    except Exception as e:
        print(f"0x0.st upload failed: {e}")

    # Method 3: Fallback to base64
    return f"data:image/png;base64,{img_base64}"
def process_images(prompt, image1, image2=None):
    """Generate an image with the Replicate API from a prompt and reference image(s).

    Args:
        prompt: Text prompt forwarded to the model.
        image1: Primary reference image (required).
        image2: Optional second reference image.

    Returns:
        tuple: ``(image, status_message, image)`` on success, or
        ``(None, status_message, None)`` on any failure. Exceptions are
        caught and reported through ``status_message`` (first 200 characters).
    """
    if not image1:
        return None, "Please upload at least one image", None
    if not os.getenv('REPLICATE_API_TOKEN'):
        return None, "Please set REPLICATE_API_TOKEN", None

    try:
        image_urls = []
        # Upload images
        url1 = upload_image_to_hosting(image1)
        image_urls.append(url1)
        if image2:
            url2 = upload_image_to_hosting(image2)
            image_urls.append(url2)

        # Run the model (using a placeholder model name - replace with actual)
        # Note: "google/nano-banana" doesn't exist - replace with actual model
        output = replicate.run(
            "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
            input={
                "prompt": prompt,
                # NOTE: when two images are supplied, no reference image is sent.
                "image": url1 if len(image_urls) == 1 else None,
                "width": 1024,
                "height": 1024,
            },
        )
        if output is None:
            return None, "No output received", None

        # Handle different output formats: a sequence of items or a single item.
        if isinstance(output, (list, tuple)):
            if not output:
                # An empty list used to be stringified to "[]" and then
                # fetched as if it were a URL; report it as missing output.
                return None, "No output received", None
            output = output[0]
        output_url = output if isinstance(output, str) else str(output)

        # Get the generated image
        img = None
        if output_url:
            response = requests.get(output_url, timeout=30)
            if response.status_code == 200:
                img = Image.open(BytesIO(response.content))

        if img:
            return img, "✨ Image generated successfully!", img
        return None, "Could not process output", None
    except Exception as e:
        return None, f"Error: {str(e)[:200]}", None
| # =========================== | |
| # Video Generation Functions (Simplified) | |
| # =========================== | |
def resize_image_for_video(image: Image.Image) -> Image.Image:
    """Center-crop ``image`` to the landscape aspect ratio, then resize it.

    The crop removes equal margins from the left/right (input wider than the
    target) or top/bottom (otherwise); the result is resized to
    ``LANDSCAPE_WIDTH`` x ``LANDSCAPE_HEIGHT`` with Lanczos resampling.
    """
    src_w, src_h = image.size
    wanted_ratio = LANDSCAPE_WIDTH / LANDSCAPE_HEIGHT

    if src_w / src_h > wanted_ratio:
        # Too wide: keep full height, trim the sides.
        crop_w = round(src_h * wanted_ratio)
        x0 = (src_w - crop_w) // 2
        crop_box = (x0, 0, x0 + crop_w, src_h)
    else:
        # Too tall (or already matching): keep full width, trim top and bottom.
        crop_h = round(src_w / wanted_ratio)
        y0 = (src_h - crop_h) // 2
        crop_box = (0, y0, src_w, y0 + crop_h)

    cropped = image.crop(crop_box)
    return cropped.resize((LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT), Image.LANCZOS)
def generate_video(
    input_image,
    prompt,
    steps=4,
    negative_prompt=default_negative_prompt,
    duration_seconds=1.5,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
):
    """Generate a video from an input image (simplified version).

    Args:
        input_image: PIL image to animate.
        prompt: Text prompt describing the motion.
        steps: Number of inference steps.
        negative_prompt: Text describing what to avoid.
        duration_seconds: Requested clip length; converted to frames at
            ``FIXED_FPS`` and capped at 17 frames.
        guidance_scale: Guidance scale forwarded to the pipeline.
        guidance_scale_2: Accepted for interface compatibility; not currently
            forwarded to the pipeline.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a random seed in ``[0, MAX_SEED]`` instead.

    Returns:
        tuple: ``(video_path, seed_used, status_message)`` on success, or
        ``(None, seed, error_message)`` when anything inside the generation
        step fails (including missing dependencies or model load errors).

    Raises:
        gr.Error: If ``input_image`` is None or PyTorch is not installed.
    """
    global video_pipe

    if input_image is None:
        raise gr.Error("Please generate or upload an image first.")
    if not TORCH_AVAILABLE:
        raise gr.Error("Video generation is not available. PyTorch is not installed.")

    # From here on torch is known to be importable, so no further
    # TORCH_AVAILABLE checks are needed.
    try:
        # Import dependencies
        video_deps = lazy_import_video_dependencies()
        if not all(video_deps):
            raise gr.Error("Video generation dependencies are not available.")
        WanImageToVideoPipeline, _WanTransformer3DModel, export_to_video = video_deps

        # Simple initialization without complex optimizations
        if video_pipe is None:
            print("Initializing video pipeline (simplified)...")
            # Clear GPU memory first
            torch.cuda.empty_cache()
            gc.collect()
            # Basic pipeline loading
            try:
                video_pipe = WanImageToVideoPipeline.from_pretrained(
                    VIDEO_MODEL_ID,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True,
                    # Diffusers pipelines document "balanced" as the supported
                    # string strategy; "auto" is rejected by from_pretrained.
                    device_map="balanced",
                )
                print("Video pipeline loaded")
            except Exception as e:
                print(f"Failed to load video pipeline: {e}")
                raise gr.Error("Could not load video model. Please try again later.")

        # Prepare video generation: cap at 17 frames and never go below 1, so
        # a zero/negative duration cannot produce a negative frame count.
        requested_frames = int(round(duration_seconds * FIXED_FPS))
        num_frames = max(1, min(17, requested_frames))
        # Snap down to the nearest value of the form 4k + 1.
        num_frames = ((num_frames - 1) // 4) * 4 + 1
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

        # Resize image
        resized_image = resize_image_for_video(input_image)

        # Generate video with minimal settings
        print(f"Generating {num_frames} frames...")
        generator_device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device=generator_device).manual_seed(current_seed)
        output_frames_list = video_pipe(
            image=resized_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=LANDSCAPE_HEIGHT,
            width=LANDSCAPE_WIDTH,
            num_frames=num_frames,
            guidance_scale=float(guidance_scale),
            num_inference_steps=int(steps),
            generator=generator,
        ).frames[0]

        # Save video. delete=False keeps the file on disk after the handle
        # closes so the returned path stays valid for the caller.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            video_path = tmpfile.name
        export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
        return video_path, current_seed, f"🎬 Video generated! ({num_frames} frames)"
    except Exception as e:
        # Free whatever memory the failed attempt left behind.
        torch.cuda.empty_cache()
        gc.collect()
        error_msg = str(e)[:200]
        if "out of memory" in error_msg.lower():
            return None, seed, "GPU memory exceeded. Try reducing duration and steps."
        return None, seed, f"Error: {error_msg}"
| # =========================== | |
| # Simple CSS | |
| # =========================== | |
# Custom stylesheet passed to gr.Blocks(css=css): limits and centers the page
# container and styles the header banner (.header-container, .logo-text,
# .subtitle) rendered at the top of the demo.
css = """
.gradio-container {
max-width: 1200px;
margin: 0 auto;
}
.header-container {
background: linear-gradient(135deg, #ffd93d 0%, #ffb347 100%);
padding: 2rem;
border-radius: 12px;
margin-bottom: 2rem;
text-align: center;
}
.logo-text {
font-size: 2.5rem;
font-weight: bold;
color: #2d3436;
margin: 0;
}
.subtitle {
color: #2d3436;
font-size: 1rem;
margin-top: 0.5rem;
}
"""
| # =========================== | |
| # Gradio Interface (Simplified) | |
| # =========================== | |
def create_demo():
    """Assemble the two-tab Gradio Blocks UI and return it without launching.

    Tab 1 generates an image via ``process_images``; tab 2 animates an image
    via ``generate_video``. A ``gr.State`` carries the generated image from
    the first tab to the second.
    """
    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        # Shared state: last generated image, handed over to the video tab.
        generated_image_state = gr.State(None)

        gr.HTML("""
<div class="header-container">
<h1 class="logo-text">🍌 Nano Banana + Video</h1>
<p class="subtitle">AI-Powered Image Generation with Video Creation</p>
</div>
""")

        with gr.Tabs():
            # Tab 1: Image Generation
            with gr.TabItem("🎨 Step 1: Generate Image"):
                with gr.Row():
                    with gr.Column():
                        style_prompt = gr.Textbox(
                            label="Style Description",
                            placeholder="Describe your style...",
                            lines=3,
                            value="A beautiful landscape in anime style",
                        )
                        image1 = gr.Image(label="Reference Image (Optional)", type="pil")
                        image2 = gr.Image(label="Secondary Image (Optional)", type="pil")
                        generate_img_btn = gr.Button("Generate Image ✨", variant="primary")
                    with gr.Column():
                        output_image = gr.Image(label="Generated Result", type="pil")
                        img_status = gr.Textbox(label="Status", interactive=False, value="Ready...")
                        # Hidden until an image has actually been generated.
                        send_to_video_btn = gr.Button(
                            "Send to Video Generation →",
                            variant="secondary",
                            visible=False,
                        )

            # Tab 2: Video Generation
            with gr.TabItem("🎬 Step 2: Generate Video"):
                with gr.Row():
                    with gr.Column():
                        video_input_image = gr.Image(type="pil", label="Input Image")
                        video_prompt = gr.Textbox(label="Animation Prompt", value=default_prompt_i2v)
                        duration_input = gr.Slider(
                            minimum=0.5, maximum=2.0, step=0.5, value=1.0,
                            label="Duration (seconds)",
                        )
                        steps_slider = gr.Slider(
                            minimum=1, maximum=8, step=1, value=4,
                            label="Inference Steps",
                        )
                        generate_video_btn = gr.Button("Generate Video 🎬", variant="primary")
                    with gr.Column():
                        video_output = gr.Video(label="Generated Video", autoplay=True)
                        video_status = gr.Textbox(label="Status", interactive=False, value="Ready...")

        # Event Handlers
        def on_image_generated(prompt, img1, img2):
            """Run image generation and reveal the hand-off button on success."""
            result_img, status_text, state_img = process_images(prompt, img1, img2)
            show_button = True if result_img else False
            return result_img, status_text, state_img, gr.update(visible=show_button)

        def send_image_to_video(img):
            """Copy the stored image into the video tab's input."""
            if not img:
                return None, "No image to send."
            return img, "Image loaded!"

        # Simplified video generation
        def generate_video_wrapper(img, prompt, duration, steps):
            """Call generate_video and reduce its result to (path, status)."""
            if not TORCH_AVAILABLE:
                return None, "Video generation requires PyTorch. Please install it first."
            try:
                video_path, _seed_used, status_text = generate_video(
                    img, prompt, steps=steps, duration_seconds=duration
                )
            except Exception as exc:
                return None, f"Error: {str(exc)[:100]}"
            return video_path, status_text

        # Wire up events
        generate_img_btn.click(
            fn=on_image_generated,
            inputs=[style_prompt, image1, image2],
            outputs=[output_image, img_status, generated_image_state, send_to_video_btn],
        )
        send_to_video_btn.click(
            fn=send_image_to_video,
            inputs=[generated_image_state],
            outputs=[video_input_image, video_status],
        )
        generate_video_btn.click(
            fn=generate_video_wrapper,
            inputs=[video_input_image, video_prompt, duration_input, steps_slider],
            outputs=[video_output, video_status],
        )

    return demo
| # =========================== | |
| # Main Launch | |
| # =========================== | |
if __name__ == "__main__":
    # Startup banner.
    banner_rule = "=" * 50
    print(banner_rule)
    print("Starting Nano Banana + Video Application")
    print(banner_rule)

    # Check environment: warn about missing pieces but still start the UI.
    if not os.getenv('REPLICATE_API_TOKEN'):
        print("Warning: REPLICATE_API_TOKEN not set. Image generation may not work.")
    if not TORCH_AVAILABLE:
        print("Warning: PyTorch not available. Video generation will be disabled.")
        print("To enable video generation, install PyTorch: pip install torch")

    launch_options = {
        "share": False,  # Set to True if you want a public link
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "debug": False,  # Set to True for debugging
    }

    try:
        # Create and launch demo
        demo = create_demo()
        demo.launch(**launch_options)
    except Exception as launch_error:
        print(f"Failed to launch application: {launch_error}")
        print("Please check your environment and dependencies.")