risunobushi commited on
Commit
f55be02
Β·
1 Parent(s): eb48e62

Add optional mask parameter to generate function

Browse files

- Add optional mask_img parameter to generate function for ComfyUI integration
- Implement smart fallback: use provided mask or base image (instead of garment)
- Add logging for mask selection debugging
- Update API payload to use proper mask URL or base image fallback

Files changed (1) hide show
  1. app.py +28 -4
app.py CHANGED
@@ -277,7 +277,13 @@ def _public_storage_url(path: str) -> str:
277
  # Main generate function
278
  # -----------------------------------------------------------------------------
279
 
280
- def generate(base_img: Image.Image, garment_img: Image.Image, workflow_choice: str, request: gr.Request) -> Image.Image:
 
 
 
 
 
 
281
  if base_img is None or garment_img is None:
282
  raise gr.Error("Please provide both images.")
283
 
@@ -299,6 +305,18 @@ def generate(base_img: Image.Image, garment_img: Image.Image, workflow_choice: s
299
 
300
  base_url = upload_image_to_supabase(base_img, base_path)
301
  garment_url = upload_image_to_supabase(garment_img, garment_path)
 
 
 
 
 
 
 
 
 
 
 
 
302
 
303
  # 2. Insert new row into processing_jobs (anon key, relies on open RLS)
304
  token_for_row = str(uuid.uuid4())
@@ -307,7 +325,7 @@ def generate(base_img: Image.Image, garment_img: Image.Image, workflow_choice: s
307
  "status": "queued",
308
  "base_image_path": base_url,
309
  "garment_image_path": garment_url,
310
- "mask_image_path": base_url,
311
  "access_token": token_for_row,
312
  "created_at": datetime.utcnow().isoformat(),
313
  }
@@ -321,11 +339,17 @@ def generate(base_img: Image.Image, garment_img: Image.Image, workflow_choice: s
321
  fn_payload = {
322
  "baseImageUrl": base_url,
323
  "garmentImageUrl": garment_url,
324
- # πŸ‘‰ hack: use garment as placeholder mask until proper mask provided
325
- "maskImageUrl": garment_url,
326
  "jobId": job_id,
327
  "workflowType": workflow_choice,
328
  }
 
 
 
 
 
 
329
  headers = {
330
  "Content-Type": "application/json",
331
  "apikey": SUPABASE_SECRET_KEY,
 
277
  # Main generate function
278
  # -----------------------------------------------------------------------------
279
 
280
+ def generate(
281
+ base_img: Image.Image,
282
+ garment_img: Image.Image,
283
+ workflow_choice: str,
284
+ request: gr.Request,
285
+ mask_img: Optional[Image.Image] = None # NEW: Optional mask parameter
286
+ ) -> Image.Image:
287
  if base_img is None or garment_img is None:
288
  raise gr.Error("Please provide both images.")
289
 
 
305
 
306
  base_url = upload_image_to_supabase(base_img, base_path)
307
  garment_url = upload_image_to_supabase(garment_img, garment_path)
308
+
309
+ # Handle optional mask image (if provided by ComfyUI or future web UI)
310
+ mask_url = None
311
+ if mask_img is not None:
312
+ print(f"[MASK] Processing user-provided mask image")
313
+ mask_filename = f"{uuid.uuid4().hex}.png"
314
+ mask_path = f"{folder}/{mask_filename}"
315
+ mask_img = downscale_image(mask_img)
316
+ mask_url = upload_image_to_supabase(mask_img, mask_path)
317
+ print(f"[MASK] Uploaded mask: {mask_url}")
318
+ else:
319
+ print(f"[MASK] No mask provided - will use base image fallback")
320
 
321
  # 2. Insert new row into processing_jobs (anon key, relies on open RLS)
322
  token_for_row = str(uuid.uuid4())
 
325
  "status": "queued",
326
  "base_image_path": base_url,
327
  "garment_image_path": garment_url,
328
+ "mask_image_path": mask_url if mask_url else base_url, # Track actual mask used
329
  "access_token": token_for_row,
330
  "created_at": datetime.utcnow().isoformat(),
331
  }
 
339
  fn_payload = {
340
  "baseImageUrl": base_url,
341
  "garmentImageUrl": garment_url,
342
+ # 🎭 Smart fallback: use provided mask OR base image (much better than garment!)
343
+ "maskImageUrl": mask_url if mask_url else base_url,
344
  "jobId": job_id,
345
  "workflowType": workflow_choice,
346
  }
347
+
348
+ # Log mask selection for debugging
349
+ if mask_url:
350
+ print(f"[API] Using user-provided mask: {mask_url}")
351
+ else:
352
+ print(f"[API] Using base image as mask fallback: {base_url}")
353
  headers = {
354
  "Content-Type": "application/json",
355
  "apikey": SUPABASE_SECRET_KEY,