ginipick committed
Commit e48246c · verified · 1 Parent(s): cea104c

Update app.py

Files changed (1): app.py (+322 −353)
app.py CHANGED
@@ -7,18 +7,12 @@ from io import BytesIO
 import time
 import tempfile
 import base64
+import spaces
+import torch
 import numpy as np
 import random
 import gc
 
-# GPU-related imports are handled conditionally later
-try:
-    import torch
-    TORCH_AVAILABLE = True
-except ImportError:
-    TORCH_AVAILABLE = False
-    print("Warning: PyTorch not available. Video generation will be disabled.")
-
 # ===========================
 # Configuration
 # ===========================
@@ -27,51 +21,23 @@ except ImportError:
 os.environ['REPLICATE_API_TOKEN'] = os.getenv('REPLICATE_API_TOKEN')
 
 # Video Model Configuration
-VIDEO_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
-LANDSCAPE_WIDTH = 832
-LANDSCAPE_HEIGHT = 480
+VIDEO_MODEL_ID = "cjwbw/videocrafter2:02e509c789964be7d70de8d8fef3a6dd18f160b37272bcccc742d5adabb9f38f"  # Using public model
+LANDSCAPE_WIDTH = 512   # Reduced for stability
+LANDSCAPE_HEIGHT = 320  # Reduced for stability
 MAX_SEED = np.iinfo(np.int32).max
-FIXED_FPS = 16
+FIXED_FPS = 8  # Reduced FPS
 MIN_FRAMES_MODEL = 8
-MAX_FRAMES_MODEL = 81
-MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS, 1)
-MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS, 1)
+MAX_FRAMES_MODEL = 32  # Reduced max frames
 
-default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
-default_negative_prompt = "static, still, no motion, frozen"
-
-# ===========================
-# Initialize Video Pipeline (Lazy Loading)
-# ===========================
-
-video_pipe = None
-video_pipeline_ready = False
-
-def lazy_import_video_dependencies():
-    """Lazy import video dependencies only when needed"""
-    global video_pipe, video_pipeline_ready
-
-    if not TORCH_AVAILABLE:
-        raise gr.Error("PyTorch is not installed. Video generation is not available.")
-
-    try:
-        # Try to import video pipeline dependencies
-        from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
-        from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
-        from diffusers.utils.export_utils import export_to_video
-
-        return WanImageToVideoPipeline, WanTransformer3DModel, export_to_video
-    except ImportError as e:
-        print(f"Warning: Video dependencies not available: {e}")
-        return None, None, None
+default_prompt_i2v = "make this image come alive, smooth animation"
+default_negative_prompt = "static, still, blurry, low quality"
 
 # ===========================
 # Image Processing Functions
 # ===========================
 
 def upload_image_to_hosting(image):
-    """Upload image to multiple hosting services with fallback"""
-    # Method 1: Try imgbb.com
+    """Upload image to hosting service"""
     try:
         buffered = BytesIO()
         image.save(buffered, format="PNG")
@@ -84,7 +50,7 @@ def upload_image_to_hosting(image):
                 'key': '6d207e02198a847aa98d0a2a901485a5',
                 'image': img_base64,
             },
-            timeout=10
+            timeout=30
        )
 
        if response.status_code == 200:
@@ -92,23 +58,9 @@
            if data.get('success'):
                return data['data']['url']
    except Exception as e:
-        print(f"imgbb upload failed: {e}")
+        print(f"Upload failed: {e}")
 
-    # Method 2: Try 0x0.st
-    try:
-        buffered = BytesIO()
-        image.save(buffered, format="PNG")
-        buffered.seek(0)
-
-        files = {'file': ('image.png', buffered, 'image/png')}
-        response = requests.post("https://0x0.st", files=files, timeout=10)
-
-        if response.status_code == 200:
-            return response.text.strip()
-    except Exception as e:
-        print(f"0x0.st upload failed: {e}")
-
-    # Method 3: Fallback to base64
+    # Fallback to base64
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    buffered.seek(0)
@@ -116,193 +68,179 @@
    return f"data:image/png;base64,{img_base64}"
 
 def process_images(prompt, image1, image2=None):
-    """Process uploaded images with Replicate API"""
+    """Process images using Replicate API"""
    if not image1:
        return None, "Please upload at least one image", None
 
    if not os.getenv('REPLICATE_API_TOKEN'):
-        return None, "Please set REPLICATE_API_TOKEN", None
+        return None, "Please set REPLICATE_API_TOKEN in Space settings", None
 
    try:
-        image_urls = []
-
-        # Upload images
+        # Upload image
        url1 = upload_image_to_hosting(image1)
-        image_urls.append(url1)
-
-        if image2:
-            url2 = upload_image_to_hosting(image2)
-            image_urls.append(url2)
 
-        # Run the model (using a placeholder model name - replace with actual)
-        # Note: "google/nano-banana" doesn't exist - replace with actual model
+        # Use SDXL for image generation/editing
        output = replicate.run(
            "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
            input={
-                "prompt": prompt,
-                "image": url1 if len(image_urls) == 1 else None,
+                "prompt": prompt + ", high quality, detailed",
+                "negative_prompt": "low quality, blurry, distorted",
                "width": 1024,
-                "height": 1024
+                "height": 1024,
+                "num_inference_steps": 25
            }
        )
 
-        if output is None:
-            return None, "No output received", None
-
-        # Get the generated image
-        img = None
-
-        # Handle different output formats
-        if isinstance(output, list) and len(output) > 0:
-            output_url = output[0]
-        elif isinstance(output, str):
-            output_url = output
-        else:
-            output_url = str(output)
-
-        if output_url:
-            response = requests.get(output_url, timeout=30)
+        if output and isinstance(output, list) and len(output) > 0:
+            img_url = output[0]
+            response = requests.get(img_url, timeout=30)
            if response.status_code == 200:
                img = Image.open(BytesIO(response.content))
+                return img, "✨ Image generated successfully!", img
 
-        if img:
-            return img, "✨ Image generated successfully!", img
-        else:
-            return None, "Could not process output", None
+        return None, "Could not process output", None
 
    except Exception as e:
-        return None, f"Error: {str(e)[:200]}", None
+        error_msg = str(e)
+        if "trial" in error_msg.lower():
+            return None, "Replicate API limit reached. Please try again later.", None
+        return None, f"Error: {error_msg[:200]}", None
 
 # ===========================
-# Video Generation Functions (Simplified)
+# Video Generation Functions
 # ===========================
 
 def resize_image_for_video(image: Image.Image) -> Image.Image:
    """Resize image for video generation"""
-    target_aspect = LANDSCAPE_WIDTH / LANDSCAPE_HEIGHT
-    width, height = image.size
-    in_aspect = width / height
-
-    if in_aspect > target_aspect:
-        new_width = round(height * target_aspect)
-        left = (width - new_width) // 2
-        image = image.crop((left, 0, left + new_width, height))
-    else:
-        new_height = round(width / target_aspect)
-        top = (height - new_height) // 2
-        image = image.crop((0, top, width, top + new_height))
+    # Convert RGBA to RGB if necessary
+    if image.mode == 'RGBA':
+        background = Image.new('RGB', image.size, (255, 255, 255))
+        background.paste(image, mask=image.split()[3])
+        image = background
 
-    return image.resize((LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT), Image.LANCZOS)
+    # Resize to target dimensions
+    image = image.resize((LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT), Image.LANCZOS)
+    return image
 
-def generate_video(
+# GPU function with proper decorator
+@spaces.GPU(duration=60)
+def generate_video_gpu(
    input_image,
    prompt,
-    steps=4,
+    steps=25,
    negative_prompt=default_negative_prompt,
-    duration_seconds=1.5,
-    guidance_scale=1,
-    guidance_scale_2=1,
+    duration_seconds=2.0,
    seed=42,
    randomize_seed=False,
 ):
-    """Generate a video from an input image (simplified version)"""
-    if input_image is None:
-        raise gr.Error("Please generate or upload an image first.")
+    """Generate video using Replicate API with GPU"""
 
-    if not TORCH_AVAILABLE:
-        raise gr.Error("Video generation is not available. PyTorch is not installed.")
+    if input_image is None:
+        return None, seed, "Please provide an input image"
 
    try:
-        # Import dependencies
-        video_deps = lazy_import_video_dependencies()
-        if not all(video_deps):
-            raise gr.Error("Video generation dependencies are not available.")
-
-        WanImageToVideoPipeline, WanTransformer3DModel, export_to_video = video_deps
-
-        global video_pipe
-
-        # Simple initialization without complex optimizations
-        if video_pipe is None:
-            print("Initializing video pipeline (simplified)...")
-
-            # Clear GPU memory first
-            if TORCH_AVAILABLE:
-                torch.cuda.empty_cache()
-                gc.collect()
-
-            # Basic pipeline loading
-            try:
-                video_pipe = WanImageToVideoPipeline.from_pretrained(
-                    VIDEO_MODEL_ID,
-                    torch_dtype=torch.float16 if TORCH_AVAILABLE else None,
-                    low_cpu_mem_usage=True,
-                    device_map="auto"
-                )
-                print("Video pipeline loaded")
-            except Exception as e:
-                print(f"Failed to load video pipeline: {e}")
-                raise gr.Error("Could not load video model. Please try again later.")
-
-        # Prepare video generation
-        num_frames = min(17, int(round(duration_seconds * FIXED_FPS)))  # Limit frames
-        num_frames = ((num_frames - 1) // 4) * 4 + 1  # Ensure divisible by 4
-
-        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+        # Clear GPU memory
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            gc.collect()
 
        # Resize image
        resized_image = resize_image_for_video(input_image)
 
-        # Generate video with minimal settings
-        print(f"Generating {num_frames} frames...")
-
-        if TORCH_AVAILABLE:
-            generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu").manual_seed(current_seed)
-        else:
-            generator = None
+        # Save resized image temporarily
+        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
+            resized_image.save(tmp_img.name)
+
+        # Upload to hosting
+        img_url = upload_image_to_hosting(resized_image)
 
-        output_frames_list = video_pipe(
-            image=resized_image,
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            height=LANDSCAPE_HEIGHT,
-            width=LANDSCAPE_WIDTH,
-            num_frames=num_frames,
-            guidance_scale=float(guidance_scale),
-            num_inference_steps=int(steps),
-            generator=generator,
-        ).frames[0]
+        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
 
-        # Save video
-        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
-            video_path = tmpfile.name
+        # Use Replicate for video generation
+        print("Generating video with Replicate...")
+        output = replicate.run(
+            VIDEO_MODEL_ID,
+            input={
+                "prompt": prompt,
+                "image": img_url,
+                "steps": int(steps),
+                "fps": FIXED_FPS,
+                "seconds": min(duration_seconds, 3),  # Limit to 3 seconds
+                "seed": current_seed
+            }
+        )
 
-        export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
+        if output:
+            # Download video
+            if isinstance(output, str):
+                video_url = output
+            elif hasattr(output, 'url'):
+                video_url = output.url()
+            else:
+                video_url = str(output)
+
+            response = requests.get(video_url, timeout=60)
+            if response.status_code == 200:
+                with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_video:
+                    tmp_video.write(response.content)
+                    return tmp_video.name, current_seed, "🎬 Video generated successfully!"
 
-        return video_path, current_seed, f"🎬 Video generated! ({num_frames} frames)"
+        return None, seed, "Failed to generate video"
 
    except Exception as e:
-        if TORCH_AVAILABLE:
+        error_msg = str(e)
+        if "out of memory" in error_msg.lower():
            torch.cuda.empty_cache()
            gc.collect()
-        error_msg = str(e)[:200]
-        if "out of memory" in error_msg.lower():
-            return None, seed, "GPU memory exceeded. Try reducing duration and steps."
-        return None, seed, f"Error: {error_msg}"
+            return None, seed, "GPU memory exceeded. Try reducing duration."
+        return None, seed, f"Error: {error_msg[:200]}"
+
+# Wrapper function for video generation
+def generate_video(
+    input_image,
+    prompt,
+    steps=25,
+    negative_prompt=default_negative_prompt,
+    duration_seconds=2.0,
+    seed=42,
+    randomize_seed=False,
+):
+    """Wrapper function that calls the GPU function"""
+    if not os.getenv('REPLICATE_API_TOKEN'):
+        return None, seed, "Please set REPLICATE_API_TOKEN in Space settings"
+
+    return generate_video_gpu(
+        input_image,
+        prompt,
+        steps,
+        negative_prompt,
+        duration_seconds,
+        seed,
+        randomize_seed
+    )
+
+# ===========================
+# Simple dummy GPU function for startup
+# ===========================
+
+@spaces.GPU(duration=1)
+def dummy_gpu_function():
+    """Dummy function to satisfy Spaces GPU requirement"""
+    return "GPU initialized"
 
 # ===========================
-# Simple CSS
+# CSS Styling
 # ===========================
 
 css = """
 .gradio-container {
-    max-width: 1200px;
-    margin: 0 auto;
+    max-width: 1200px !important;
+    margin: 0 auto !important;
 }
 .header-container {
-    background: linear-gradient(135deg, #ffd93d 0%, #ffb347 100%);
+    background: linear-gradient(135deg, #ffd93d, #ffb347);
    padding: 2rem;
-    border-radius: 12px;
+    border-radius: 15px;
    margin-bottom: 2rem;
    text-align: center;
 }
@@ -310,197 +248,228 @@ css = """
    font-size: 2.5rem;
    font-weight: bold;
    color: #2d3436;
-    margin: 0;
 }
 .subtitle {
    color: #2d3436;
-    font-size: 1rem;
+    font-size: 1.1rem;
    margin-top: 0.5rem;
 }
+.gr-button {
+    font-size: 1rem !important;
+    padding: 12px 24px !important;
+}
+.gr-button-primary {
+    background: linear-gradient(135deg, #ffd93d, #ffb347) !important;
+    border: none !important;
+}
+.gr-button-secondary {
+    background: linear-gradient(135deg, #667eea, #764ba2) !important;
+    color: white !important;
+    border: none !important;
+}
 """
 
 # ===========================
-# Gradio Interface (Simplified)
+# Gradio Interface
 # ===========================
 
-def create_demo():
-    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
-        # Shared state
-        generated_image_state = gr.State(None)
-
-        gr.HTML("""
-            <div class="header-container">
-                <h1 class="logo-text">🍌 Nano Banana + Video</h1>
-                <p class="subtitle">AI-Powered Image Generation with Video Creation</p>
-            </div>
-        """)
-
-        with gr.Tabs():
-            # Tab 1: Image Generation
-            with gr.TabItem("🎨 Step 1: Generate Image"):
-                with gr.Row():
-                    with gr.Column():
-                        style_prompt = gr.Textbox(
-                            label="Style Description",
-                            placeholder="Describe your style...",
-                            lines=3,
-                            value="A beautiful landscape in anime style"
-                        )
-
+with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
+    # Initialize GPU on startup
+    startup_status = gr.State(dummy_gpu_function())
+
+    # Shared state
+    generated_image_state = gr.State(None)
+
+    gr.HTML("""
+        <div class="header-container">
+            <h1 class="logo-text">🍌 Nano Banana + Video</h1>
+            <p class="subtitle">AI Image Generation with Video Creation</p>
+            <p style="color: #636e72; font-size: 0.9rem; margin-top: 10px;">
+                ⚠️ Note: Add REPLICATE_API_TOKEN in Space Settings > Repository secrets
+            </p>
+        </div>
+    """)
+
+    with gr.Tabs():
+        # Tab 1: Image Generation
+        with gr.TabItem("🎨 Step 1: Generate Image"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    style_prompt = gr.Textbox(
+                        label="Image Description",
+                        placeholder="Describe what you want to create...",
+                        lines=3,
+                        value="A beautiful fantasy landscape with mountains and a river, studio ghibli style"
+                    )
+
+                    with gr.Row():
                        image1 = gr.Image(
                            label="Reference Image (Optional)",
-                            type="pil"
+                            type="pil",
+                            height=200
                        )
-
                        image2 = gr.Image(
-                            label="Secondary Image (Optional)",
-                            type="pil"
-                        )
-
-                        generate_img_btn = gr.Button(
-                            "Generate Image ✨",
-                            variant="primary"
+                            label="Style Reference (Optional)",
+                            type="pil",
+                            height=200
                        )
 
-                    with gr.Column():
-                        output_image = gr.Image(
-                            label="Generated Result",
-                            type="pil"
-                        )
-
-                        img_status = gr.Textbox(
-                            label="Status",
-                            interactive=False,
-                            value="Ready..."
-                        )
-
-                        send_to_video_btn = gr.Button(
-                            "Send to Video Generation →",
-                            variant="secondary",
-                            visible=False
-                        )
+                    generate_img_btn = gr.Button(
+                        "🎨 Generate Image",
+                        variant="primary",
+                        size="lg"
+                    )
+
+                with gr.Column(scale=1):
+                    output_image = gr.Image(
+                        label="Generated Result",
+                        type="pil",
+                        height=400
+                    )
+
+                    img_status = gr.Textbox(
+                        label="Status",
+                        interactive=False,
+                        value="Ready to generate..."
+                    )
+
+                    send_to_video_btn = gr.Button(
+                        "➡️ Send to Video Generation",
+                        variant="secondary",
+                        visible=False
+                    )
+
+        # Tab 2: Video Generation
+        with gr.TabItem("🎬 Step 2: Generate Video"):
+            gr.Markdown("### Transform your image into a video")
 
-            # Tab 2: Video Generation
-            with gr.TabItem("🎬 Step 2: Generate Video"):
-                with gr.Row():
-                    with gr.Column():
-                        video_input_image = gr.Image(
-                            type="pil",
-                            label="Input Image"
-                        )
-
-                        video_prompt = gr.Textbox(
-                            label="Animation Prompt",
-                            value=default_prompt_i2v
-                        )
-
+            with gr.Row():
+                with gr.Column(scale=1):
+                    video_input_image = gr.Image(
+                        type="pil",
+                        label="Input Image",
+                        height=300
+                    )
+
+                    video_prompt = gr.Textbox(
+                        label="Animation Description",
+                        value=default_prompt_i2v,
+                        lines=2
+                    )
+
+                    with gr.Row():
                        duration_input = gr.Slider(
-                            minimum=0.5,
-                            maximum=2.0,
+                            minimum=1.0,
+                            maximum=3.0,
                            step=0.5,
-                            value=1.0,
+                            value=2.0,
                            label="Duration (seconds)"
                        )
 
                        steps_slider = gr.Slider(
-                            minimum=1,
-                            maximum=8,
-                            step=1,
-                            value=4,
-                            label="Inference Steps"
-                        )
-
-                        generate_video_btn = gr.Button(
-                            "Generate Video 🎬",
-                            variant="primary"
+                            minimum=10,
+                            maximum=50,
+                            step=5,
+                            value=25,
+                            label="Quality Steps"
                        )
 
-                    with gr.Column():
-                        video_output = gr.Video(
-                            label="Generated Video",
-                            autoplay=True
+                    with gr.Row():
+                        video_seed = gr.Slider(
+                            label="Seed",
+                            minimum=0,
+                            maximum=MAX_SEED,
+                            step=1,
+                            value=42
                        )
 
-                        video_status = gr.Textbox(
-                            label="Status",
-                            interactive=False,
-                            value="Ready..."
+                        randomize_seed = gr.Checkbox(
+                            label="Random seed",
+                            value=True
                        )
-
-        # Event Handlers
-        def on_image_generated(prompt, img1, img2):
-            img, status, state_img = process_images(prompt, img1, img2)
-            if img:
-                return img, status, state_img, gr.update(visible=True)
-            return img, status, state_img, gr.update(visible=False)
-
-        def send_image_to_video(img):
-            if img:
-                return img, "Image loaded!"
-            return None, "No image to send."
-
-        # Wire up events
-        generate_img_btn.click(
-            fn=on_image_generated,
-            inputs=[style_prompt, image1, image2],
-            outputs=[output_image, img_status, generated_image_state, send_to_video_btn]
-        )
-
-        send_to_video_btn.click(
-            fn=send_image_to_video,
-            inputs=[generated_image_state],
-            outputs=[video_input_image, video_status]
-        )
-
-        # Simplified video generation
-        def generate_video_wrapper(img, prompt, duration, steps):
-            if not TORCH_AVAILABLE:
-                return None, "Video generation requires PyTorch. Please install it first."
-
-            try:
-                video_path, seed, status = generate_video(
-                    img, prompt, steps=steps, duration_seconds=duration
-                )
-                return video_path, status
-            except Exception as e:
-                return None, f"Error: {str(e)[:100]}"
-
-        generate_video_btn.click(
-            fn=generate_video_wrapper,
-            inputs=[video_input_image, video_prompt, duration_input, steps_slider],
-            outputs=[video_output, video_status]
-        )
-
-        return demo
-
-# ===========================
-# Main Launch
-# ===========================
-
-if __name__ == "__main__":
-    print("=" * 50)
-    print("Starting Nano Banana + Video Application")
-    print("=" * 50)
+
+                    video_negative_prompt = gr.Textbox(
+                        label="Negative Prompt",
+                        value=default_negative_prompt,
+                        lines=2
+                    )
+
+                    generate_video_btn = gr.Button(
+                        "🎬 Generate Video",
+                        variant="primary",
+                        size="lg"
+                    )
+
+                with gr.Column(scale=1):
+                    video_output = gr.Video(
+                        label="Generated Video",
+                        autoplay=True,
+                        height=400
+                    )
+
+                    video_status = gr.Textbox(
+                        label="Status",
+                        interactive=False,
+                        value="Ready to generate video..."
+                    )
 
-    # Check environment
-    if not os.getenv('REPLICATE_API_TOKEN'):
-        print("Warning: REPLICATE_API_TOKEN not set. Image generation may not work.")
+    # Event Handlers
+    def on_image_generated(prompt, img1, img2):
+        img, status, state_img = process_images(prompt, img1, img2)
+        if img:
+            return img, status, state_img, gr.update(visible=True)
+        return None, status, None, gr.update(visible=False)
 
-    if not TORCH_AVAILABLE:
-        print("Warning: PyTorch not available. Video generation will be disabled.")
-        print("To enable video generation, install PyTorch: pip install torch")
+    def send_image_to_video(img):
+        if img:
+            return img, "Image loaded! Ready to generate video."
+        return None, "No image to send."
 
-    try:
-        # Create and launch demo
-        demo = create_demo()
-
-        demo.launch(
-            share=False,  # Set to True if you want a public link
-            server_name="0.0.0.0",
-            server_port=7860,
-            show_error=True,
-            debug=False  # Set to True for debugging
-        )
-    except Exception as e:
-        print(f"Failed to launch application: {e}")
-        print("Please check your environment and dependencies.")
+    # Connect events
+    generate_img_btn.click(
+        fn=on_image_generated,
+        inputs=[style_prompt, image1, image2],
+        outputs=[output_image, img_status, generated_image_state, send_to_video_btn]
+    )
+
+    send_to_video_btn.click(
+        fn=send_image_to_video,
+        inputs=[generated_image_state],
+        outputs=[video_input_image, video_status]
+    )
+
+    generate_video_btn.click(
+        fn=generate_video,
+        inputs=[
+            video_input_image,
+            video_prompt,
+            steps_slider,
+            video_negative_prompt,
+            duration_input,
+            video_seed,
+            randomize_seed
+        ],
+        outputs=[video_output, video_seed, video_status]
+    )
+
+    # Examples
+    gr.Examples(
+        examples=[
+            ["A majestic castle on a hilltop at sunset, fantasy art style"],
+            ["Cute robot in a flower garden, pixar animation style"],
+            ["Northern lights over a frozen lake, photorealistic"],
+            ["Ancient temple in a jungle, mysterious atmosphere"],
+        ],
+        inputs=[style_prompt],
+        label="Example Prompts"
+    )
+
+# Launch the app
+if __name__ == "__main__":
+    print("Starting Nano Banana + Video app...")
+    print("Make sure to set REPLICATE_API_TOKEN in your Space settings!")
+
+    demo.launch(
+        share=False,
+        show_error=True
+    )
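Note: the updated app.py leans on two external mechanisms that are easy to misconfigure: the `spaces.GPU` decorator (ZeroGPU time slicing on Hugging Face Spaces) and Replicate-hosted model calls. The sketch below mirrors those patterns, reusing the same SDXL slug and input keys as `process_images()`; the helper names and the standalone `__main__` harness are illustrative assumptions, not code from this commit.

# Illustrative sketch only; mirrors patterns from the committed app.py.
import os

import replicate
import spaces
import torch

@spaces.GPU(duration=60)  # on ZeroGPU hardware, reserves a GPU for up to 60s per call
def check_device() -> str:
    # CUDA is only visible inside @spaces.GPU-decorated functions on ZeroGPU Spaces.
    return "cuda" if torch.cuda.is_available() else "cpu"

def run_sdxl(prompt: str):
    # Same model slug and input keys as process_images() in the commit.
    if not os.getenv("REPLICATE_API_TOKEN"):
        raise RuntimeError("Set REPLICATE_API_TOKEN (e.g. in Space repository secrets)")
    return replicate.run(
        "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
        input={"prompt": prompt, "width": 1024, "height": 1024, "num_inference_steps": 25},
    )

if __name__ == "__main__":
    print(check_device())
    print(run_sdxl("a fantasy landscape, studio ghibli style"))

The commit's `generate_video_gpu()` combines both patterns in one function: it runs under `@spaces.GPU(duration=60)` but delegates the actual generation to Replicate, so the reserved GPU is used only for the `torch.cuda` memory cleanup.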