import gradio as gr
import replicate
import os
from PIL import Image
import requests
from io import BytesIO
import time
import tempfile
import base64
import numpy as np
import random
import gc
# GPU-related imports are handled conditionally later
try:
import torch
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
print("Warning: PyTorch not available. Video generation will be disabled.")
# ===========================
# Configuration
# ===========================
# Set up Replicate API key
_replicate_token = os.getenv('REPLICATE_API_TOKEN')
if _replicate_token:  # guard: assigning None to os.environ raises TypeError
    os.environ['REPLICATE_API_TOKEN'] = _replicate_token
# Video Model Configuration
VIDEO_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
LANDSCAPE_WIDTH = 832
LANDSCAPE_HEIGHT = 480
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS, 1)
default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "static, still, no motion, frozen"
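# Durations are converted to frame counts at FIXED_FPS; the Wan pipeline works
# on frame counts of the form 4k + 1, which generate_video() enforces below.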
# ===========================
# Initialize Video Pipeline (Lazy Loading)
# ===========================
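# The pipeline is created on the first video request rather than at import
# time, so the app can still start when video dependencies are missing.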
video_pipe = None
video_pipeline_ready = False
def lazy_import_video_dependencies():
"""Lazy import video dependencies only when needed"""
if not TORCH_AVAILABLE:
raise gr.Error("PyTorch is not installed. Video generation is not available.")
try:
# Try to import video pipeline dependencies
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
return WanImageToVideoPipeline, WanTransformer3DModel, export_to_video
except ImportError as e:
print(f"Warning: Video dependencies not available: {e}")
return None, None, None
# ===========================
# Image Processing Functions
# ===========================
def upload_image_to_hosting(image):
"""Upload image to multiple hosting services with fallback"""
# Method 1: Try imgbb.com
try:
buffered = BytesIO()
image.save(buffered, format="PNG")
buffered.seek(0)
img_base64 = base64.b64encode(buffered.getvalue()).decode()
response = requests.post(
"https://api.imgbb.com/1/upload",
data={
'key': '6d207e02198a847aa98d0a2a901485a5',
'image': img_base64,
},
timeout=10
)
if response.status_code == 200:
data = response.json()
if data.get('success'):
return data['data']['url']
except Exception as e:
print(f"imgbb upload failed: {e}")
# Method 2: Try 0x0.st
try:
buffered = BytesIO()
image.save(buffered, format="PNG")
buffered.seek(0)
files = {'file': ('image.png', buffered, 'image/png')}
response = requests.post("https://0x0.st", files=files, timeout=10)
if response.status_code == 200:
return response.text.strip()
except Exception as e:
print(f"0x0.st upload failed: {e}")
# Method 3: Fallback to base64
buffered = BytesIO()
image.save(buffered, format="PNG")
buffered.seek(0)
img_base64 = base64.b64encode(buffered.getvalue()).decode()
return f"data:image/png;base64,{img_base64}"
def process_images(prompt, image1, image2=None):
"""Process uploaded images with Replicate API"""
if not image1:
return None, "Please upload at least one image", None
if not os.getenv('REPLICATE_API_TOKEN'):
return None, "Please set REPLICATE_API_TOKEN", None
try:
image_urls = []
# Upload images
url1 = upload_image_to_hosting(image1)
image_urls.append(url1)
if image2:
url2 = upload_image_to_hosting(image2)
image_urls.append(url2)
        # Run the image model. The app's branding references "nano-banana",
        # which is not an actual Replicate model, so SDXL is used instead.
        model_input = {
            "prompt": prompt,
            "width": 1024,
            "height": 1024,
            # SDXL accepts a single init image; use the first upload and
            # ignore the secondary reference image.
            "image": image_urls[0],
        }
        output = replicate.run(
            "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
            input=model_input
        )
if output is None:
return None, "No output received", None
# Get the generated image
img = None
        # Handle the shapes replicate.run can return: a list of outputs,
        # a plain URL string, or a file-like output object
        if isinstance(output, list) and len(output) > 0:
            output_url = str(output[0])
        elif isinstance(output, str):
            output_url = output
        else:
            output_url = str(output)
if output_url:
response = requests.get(output_url, timeout=30)
if response.status_code == 200:
img = Image.open(BytesIO(response.content))
if img:
return img, "✨ Image generated successfully!", img
else:
return None, "Could not process output", None
except Exception as e:
return None, f"Error: {str(e)[:200]}", None
# ===========================
# Video Generation Functions (Simplified)
# ===========================
def resize_image_for_video(image: Image.Image) -> Image.Image:
"""Resize image for video generation"""
target_aspect = LANDSCAPE_WIDTH / LANDSCAPE_HEIGHT
width, height = image.size
in_aspect = width / height
if in_aspect > target_aspect:
new_width = round(height * target_aspect)
left = (width - new_width) // 2
image = image.crop((left, 0, left + new_width, height))
else:
new_height = round(width / target_aspect)
top = (height - new_height) // 2
image = image.crop((0, top, width, top + new_height))
return image.resize((LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT), Image.LANCZOS)
def generate_video(
input_image,
prompt,
steps=4,
negative_prompt=default_negative_prompt,
duration_seconds=1.5,
guidance_scale=1,
guidance_scale_2=1,
seed=42,
randomize_seed=False,
):
"""Generate a video from an input image (simplified version)"""
if input_image is None:
raise gr.Error("Please generate or upload an image first.")
if not TORCH_AVAILABLE:
raise gr.Error("Video generation is not available. PyTorch is not installed.")
try:
# Import dependencies
video_deps = lazy_import_video_dependencies()
if not all(video_deps):
raise gr.Error("Video generation dependencies are not available.")
WanImageToVideoPipeline, WanTransformer3DModel, export_to_video = video_deps
global video_pipe
# Simple initialization without complex optimizations
if video_pipe is None:
print("Initializing video pipeline (simplified)...")
            # Free cached accelerator memory before loading the pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
# Basic pipeline loading
try:
                video_pipe = WanImageToVideoPipeline.from_pretrained(
                    VIDEO_MODEL_ID,
                    torch_dtype=torch.float16,  # TORCH_AVAILABLE is guaranteed above
                    low_cpu_mem_usage=True,
                )
                # Move the pipeline to the GPU explicitly when one is available
                if torch.cuda.is_available():
                    video_pipe = video_pipe.to("cuda")
print("Video pipeline loaded")
except Exception as e:
print(f"Failed to load video pipeline: {e}")
raise gr.Error("Could not load video model. Please try again later.")
# Prepare video generation
        num_frames = min(17, int(round(duration_seconds * FIXED_FPS)))  # Cap frame count to bound memory use
        num_frames = ((num_frames - 1) // 4) * 4 + 1  # Snap to 4k + 1, the frame counts the Wan pipeline expects
current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
# Resize image
resized_image = resize_image_for_video(input_image)
# Generate video with minimal settings
print(f"Generating {num_frames} frames...")
if TORCH_AVAILABLE:
generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu").manual_seed(current_seed)
else:
generator = None
output_frames_list = video_pipe(
image=resized_image,
prompt=prompt,
negative_prompt=negative_prompt,
height=LANDSCAPE_HEIGHT,
width=LANDSCAPE_WIDTH,
num_frames=num_frames,
guidance_scale=float(guidance_scale),
num_inference_steps=int(steps),
generator=generator,
).frames[0]
# Save video
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
video_path = tmpfile.name
export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
return video_path, current_seed, f"🎬 Video generated! ({num_frames} frames)"
except Exception as e:
        if TORCH_AVAILABLE and torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        error_msg = str(e)[:200]
        if "out of memory" in error_msg.lower():
            return None, seed, "GPU memory exceeded. Try reducing the duration or the number of steps."
return None, seed, f"Error: {error_msg}"
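# Example (hypothetical standalone use outside the Gradio UI; "photo.jpg" is
# a placeholder path):
#   frame = Image.open("photo.jpg")
#   path, seed, status = generate_video(frame, default_prompt_i2v,
#                                       steps=4, duration_seconds=1.0)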
# ===========================
# Simple CSS
# ===========================
css = """
.gradio-container {
max-width: 1200px;
margin: 0 auto;
}
.header-container {
background: linear-gradient(135deg, #ffd93d 0%, #ffb347 100%);
padding: 2rem;
border-radius: 12px;
margin-bottom: 2rem;
text-align: center;
}
.logo-text {
font-size: 2.5rem;
font-weight: bold;
color: #2d3436;
margin: 0;
}
.subtitle {
color: #2d3436;
font-size: 1rem;
margin-top: 0.5rem;
}
"""
# ===========================
# Gradio Interface (Simplified)
# ===========================
def create_demo():
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
# Shared state
generated_image_state = gr.State(None)
gr.HTML("""
<div class="header-container">
<h1 class="logo-text">🍌 Nano Banana + Video</h1>
<p class="subtitle">AI-Powered Image Generation with Video Creation</p>
</div>
""")
with gr.Tabs():
# Tab 1: Image Generation
with gr.TabItem("🎨 Step 1: Generate Image"):
with gr.Row():
with gr.Column():
style_prompt = gr.Textbox(
label="Style Description",
placeholder="Describe your style...",
lines=3,
value="A beautiful landscape in anime style"
)
image1 = gr.Image(
label="Reference Image (Optional)",
type="pil"
)
image2 = gr.Image(
label="Secondary Image (Optional)",
type="pil"
)
generate_img_btn = gr.Button(
"Generate Image ✨",
variant="primary"
)
with gr.Column():
output_image = gr.Image(
label="Generated Result",
type="pil"
)
img_status = gr.Textbox(
label="Status",
interactive=False,
value="Ready..."
)
send_to_video_btn = gr.Button(
"Send to Video Generation →",
variant="secondary",
visible=False
)
# Tab 2: Video Generation
with gr.TabItem("🎬 Step 2: Generate Video"):
with gr.Row():
with gr.Column():
video_input_image = gr.Image(
type="pil",
label="Input Image"
)
video_prompt = gr.Textbox(
label="Animation Prompt",
value=default_prompt_i2v
)
duration_input = gr.Slider(
minimum=0.5,
maximum=2.0,
step=0.5,
value=1.0,
label="Duration (seconds)"
)
steps_slider = gr.Slider(
minimum=1,
maximum=8,
step=1,
value=4,
label="Inference Steps"
)
generate_video_btn = gr.Button(
"Generate Video 🎬",
variant="primary"
)
with gr.Column():
video_output = gr.Video(
label="Generated Video",
autoplay=True
)
video_status = gr.Textbox(
label="Status",
interactive=False,
value="Ready..."
)
# Event Handlers
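        # on_image_generated forwards the result into shared state and reveals
        # the "Send to Video" button only once an image actually exists.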
def on_image_generated(prompt, img1, img2):
img, status, state_img = process_images(prompt, img1, img2)
if img:
return img, status, state_img, gr.update(visible=True)
return img, status, state_img, gr.update(visible=False)
def send_image_to_video(img):
if img:
return img, "Image loaded!"
return None, "No image to send."
# Wire up events
generate_img_btn.click(
fn=on_image_generated,
inputs=[style_prompt, image1, image2],
outputs=[output_image, img_status, generated_image_state, send_to_video_btn]
)
send_to_video_btn.click(
fn=send_image_to_video,
inputs=[generated_image_state],
outputs=[video_input_image, video_status]
)
# Simplified video generation
def generate_video_wrapper(img, prompt, duration, steps):
if not TORCH_AVAILABLE:
return None, "Video generation requires PyTorch. Please install it first."
try:
video_path, seed, status = generate_video(
img, prompt, steps=steps, duration_seconds=duration
)
return video_path, status
except Exception as e:
return None, f"Error: {str(e)[:100]}"
generate_video_btn.click(
fn=generate_video_wrapper,
inputs=[video_input_image, video_prompt, duration_input, steps_slider],
outputs=[video_output, video_status]
)
return demo
# ===========================
# Main Launch
# ===========================
if __name__ == "__main__":
print("=" * 50)
print("Starting Nano Banana + Video Application")
print("=" * 50)
# Check environment
if not os.getenv('REPLICATE_API_TOKEN'):
print("Warning: REPLICATE_API_TOKEN not set. Image generation may not work.")
if not TORCH_AVAILABLE:
print("Warning: PyTorch not available. Video generation will be disabled.")
print("To enable video generation, install PyTorch: pip install torch")
try:
# Create and launch demo
demo = create_demo()
demo.launch(
share=False, # Set to True if you want a public link
server_name="0.0.0.0",
server_port=7860,
show_error=True,
debug=False # Set to True for debugging
)
except Exception as e:
print(f"Failed to launch application: {e}")
print("Please check your environment and dependencies.")