Spaces:
Running
Running
"""Gradio demo Space for Trainingless — see plan.md for details."""
from __future__ import annotations

import base64
import io
import os
import time
import uuid
from datetime import datetime, timezone
from typing import Optional, Tuple

import gradio as gr
import requests
from dotenv import load_dotenv
from PIL import Image
from supabase import Client, create_client
# -----------------------------------------------------------------------------
# Environment & Supabase setup
# -----------------------------------------------------------------------------
# Load a .env file when running locally. The HF Spaces runtime injects the
# same names via its Secrets mechanism, so calling load_dotenv() is harmless.
load_dotenv()

# Project base URL. Any trailing slash is stripped so that URLs derived from
# it by string formatting (function URL, public storage URLs) never contain
# "//".
SUPABASE_URL: str = os.getenv("SUPABASE_URL", "").rstrip("/")
# Use a *secret* (server-only) key so the backend bypasses RLS.
SUPABASE_SECRET_KEY: str = os.getenv("SUPABASE_SECRET_KEY", "")
# (Optional) Override which Edge Function gets called.
SUPABASE_FUNCTION_URL: str = os.getenv(
    "SUPABASE_FUNCTION_URL", f"{SUPABASE_URL}/functions/v1/process-image"
)
# Storage bucket for uploads. Must be *public*: image URLs handed to the Edge
# Function are built from the unauthenticated `/object/public/` endpoint.
UPLOAD_BUCKET = os.getenv("SUPABASE_UPLOAD_BUCKET", "images")
# Timeout in seconds for the Edge Function HTTP call.
REQUEST_TIMEOUT = int(os.getenv("SUPABASE_FN_TIMEOUT", "240"))
# Model workflows recognised by the Edge Function (also the UI radio choices).
WORKFLOW_CHOICES = [
    "eyewear",
    "footwear",
    "dress",
]

# Fail fast at import time instead of on the first request.
if not SUPABASE_URL or not SUPABASE_SECRET_KEY:
    raise RuntimeError(
        "SUPABASE_URL and SUPABASE_SECRET_KEY must be set in the environment."
    )
# -----------------------------------------------------------------------------
# Supabase client — server-side: authenticate with secret key (bypasses RLS)
# -----------------------------------------------------------------------------
supabase: Client = create_client(SUPABASE_URL, SUPABASE_SECRET_KEY)


def _bucket_name(bucket: object) -> Optional[str]:
    """Return a bucket's name whether the SDK yields bucket objects or dicts.

    storage3's ``list_buckets()`` returns bucket objects exposing ``.name``;
    plain dicts are accepted too for older/alternative client versions.
    """
    if isinstance(bucket, dict):
        return bucket.get("name")
    return getattr(bucket, "name", None)


# Ensure the uploads bucket exists (idempotent). Creating it requires a key
# with storage admin rights (the secret key above); any failure is logged and
# startup continues.
try:
    buckets = supabase.storage.list_buckets()  # type: ignore[attr-defined]
    bucket_names = (
        {_bucket_name(b) for b in buckets}
        if isinstance(buckets, (list, tuple))
        else set()
    )
    if UPLOAD_BUCKET not in bucket_names:
        try:
            # storage3 takes bucket settings via the `options` mapping, not as
            # keyword arguments.
            supabase.storage.create_bucket(  # type: ignore[attr-defined]
                UPLOAD_BUCKET,
                options={"public": True},
            )
            print(f"[startup] Created bucket '{UPLOAD_BUCKET}'.")
        except Exception as create_exc:  # noqa: BLE001
            print(f"[startup] Could not create bucket '{UPLOAD_BUCKET}': {create_exc!r}")
except Exception as exc:  # noqa: BLE001
    # Non-fatal. The bucket probably already exists or we don't have perms.
    print(f"[startup] Bucket check/create raised {exc!r}. Continuing…")
# -----------------------------------------------------------------------------
# Helper functions
# -----------------------------------------------------------------------------
def pil_to_bytes(img: Image.Image) -> bytes:
    """Serialize *img* to an in-memory PNG and return the encoded bytes."""
    sink = io.BytesIO()
    try:
        img.save(sink, format="PNG")
        png_bytes = sink.getvalue()
    finally:
        sink.close()
    return png_bytes
def upload_image_to_supabase(img: Image.Image, path: str) -> str:
    """Store *img* as PNG at ``<UPLOAD_BUCKET>/<path>`` and return its public URL.

    An existing object at the same path is overwritten (upsert). The returned
    URL uses the public-object endpoint, so it only resolves if the bucket is
    public.
    """
    payload = pil_to_bytes(img)
    # The storage client expects the upsert flag as the *string* "true".
    file_options = {"content-type": "image/png", "upsert": "true"}
    bucket = supabase.storage.from_(UPLOAD_BUCKET)  # type: ignore[attr-defined]
    bucket.upload(path, payload, file_options)
    return f"{SUPABASE_URL}/storage/v1/object/public/{UPLOAD_BUCKET}/{path}"
def wait_for_job_completion(
    job_id: str,
    timeout_s: int = 600,
    poll_interval_s: float = 5.0,
) -> Optional[str]:
    """Block until the ``processing_jobs`` row *job_id* reaches ``completed``.

    A Realtime subscription on the row is attempted first; its callback
    (``_realtime_callback``) stores the result URL in ``_RESULT_CACHE``. The
    table is also polled directly as a safety net: every *poll_interval_s*
    seconds while the subscription is active, or on every ~1 s loop tick if
    subscribing failed. The first poll happens immediately, which also covers
    a job that finished before the subscription was set up.

    Args:
        job_id: Primary key of the row in ``processing_jobs``.
        timeout_s: Maximum number of seconds to wait.
        poll_interval_s: Seconds between DB polls while Realtime is active.

    Returns:
        The row's ``result_image_url`` once its status is ``"completed"``, or
        ``None`` if that did not happen within *timeout_s*.
    """
    completed_image: Optional[str] = None
    channel = None
    did_subscribe = False
    try:
        # Docs: https://supabase.com/docs/reference/python/creating-channels
        # NOTE(review): confirm this `.on("postgres_changes", ...)` call shape
        # against the installed supabase-py version; any error falls back to
        # polling below.
        channel = (
            # Per-job channel name so concurrent requests don't share a topic.
            supabase.channel(f"job_channel_{job_id}")
            .on(
                "postgres_changes",
                {
                    "event": "UPDATE",
                    "schema": "public",
                    "table": "processing_jobs",
                    "filter": f"id=eq.{job_id}",
                },
                lambda payload: _realtime_callback(payload, job_id),
            )
            .subscribe()
        )
        did_subscribe = True
    except Exception as exc:  # noqa: BLE001
        print(f"[wait] Realtime subscription failed — will poll: {exc!r}")

    start = time.monotonic()
    last_poll = float("-inf")  # forces a poll on the first iteration
    try:
        while time.monotonic() - start < timeout_s:
            cached = _RESULT_CACHE.pop(job_id, None)
            if cached:
                completed_image = cached
                break
            now = time.monotonic()
            # Compare against the last poll time; exact float arithmetic on
            # the elapsed time would almost never hit a multiple of the interval.
            if not did_subscribe or now - last_poll >= poll_interval_s:
                last_poll = now
                data = (
                    supabase.table("processing_jobs")
                    .select("status,result_image_url")
                    .eq("id", job_id)
                    .single()
                    .execute()
                )
                if data.data and data.data["status"] == "completed":
                    completed_image = data.data.get("result_image_url")
                    break
            time.sleep(1)
    finally:
        # Discard any callback result that was not consumed above so the
        # cache does not accumulate entries for finished/abandoned jobs.
        _RESULT_CACHE.pop(job_id, None)
        if did_subscribe:
            try:
                supabase.remove_channel(channel)
            except Exception as exc:  # noqa: BLE001
                # Best-effort cleanup: log instead of failing the request.
                print(f"[wait] Could not remove realtime channel: {exc!r}")
    return completed_image
| _RESULT_CACHE: dict[str, str] = {} | |
| def _realtime_callback(payload: dict, job_id: str) -> None: | |
| new = payload.get("new", {}) # type: ignore[index] | |
| if new.get("status") == "completed": | |
| _RESULT_CACHE[job_id] = new.get("result_image_url") | |
MAX_PIXELS = 1_500_000  # 1.5 megapixels ceiling for each uploaded image


def downscale_image(img: Image.Image, max_pixels: int = MAX_PIXELS) -> Image.Image:
    """Shrink *img* uniformly so that width × height ≤ *max_pixels*.

    Both sides are scaled by the same factor (aspect ratio kept up to integer
    truncation) and clamped to at least 1 px. An image already within the
    budget is returned unchanged (the same object).
    """
    width, height = img.size
    area = width * height
    if area > max_pixels:
        factor = (max_pixels / area) ** 0.5
        target = (max(1, int(width * factor)), max(1, int(height * factor)))
        return img.resize(target, Image.LANCZOS)
    return img
| def _public_storage_url(path: str) -> str: | |
| """Return a public (https) URL given an object *path* inside any bucket. | |
| If *path* already looks like a full URL, it is returned unchanged. | |
| """ | |
| if path.startswith("http://") or path.startswith("https://"): | |
| return path | |
| # Ensure no leading slash. | |
| return f"{SUPABASE_URL}/storage/v1/object/public/{path.lstrip('/')}" | |
# -----------------------------------------------------------------------------
# Main generate function
# -----------------------------------------------------------------------------
def _trigger_edge_function(fn_payload: dict) -> None:
    """POST *fn_payload* to the processing Edge Function.

    Raises:
        gr.Error: if the request cannot be sent (timeout, connection error)
            or the function answers with an HTTP status >= 400.
    """
    headers = {
        "Content-Type": "application/json",
        "apikey": SUPABASE_SECRET_KEY,
        "Authorization": f"Bearer {SUPABASE_SECRET_KEY}",
    }
    try:
        resp = requests.post(
            SUPABASE_FUNCTION_URL,
            json=fn_payload,
            headers=headers,
            timeout=REQUEST_TIMEOUT,
        )
    except requests.RequestException as exc:
        # Show a readable message in the UI instead of a raw traceback.
        raise gr.Error(f"Backend request failed: {exc}") from exc
    if not resp.ok:
        raise gr.Error(f"Backend error: {resp.text}")


def _load_result_image(result: str) -> Image.Image:
    """Load the job's result, given as a ``data:image`` URI or a URL/path.

    Non-URI values go through ``_public_storage_url`` (absolute URLs pass
    through, storage paths become public-object URLs).

    Raises:
        gr.Error: if downloading the result image fails.
    """
    if result.startswith("data:image"):
        _, b64 = result.split(",", 1)
        raw = base64.b64decode(b64)
    else:
        try:
            resp_img = requests.get(_public_storage_url(result), timeout=30)
            resp_img.raise_for_status()
        except requests.RequestException as exc:
            raise gr.Error(f"Could not download result image: {exc}") from exc
        raw = resp_img.content
    return Image.open(io.BytesIO(raw)).convert("RGBA")


def generate(base_img: Image.Image, garment_img: Image.Image, workflow_choice: str) -> Image.Image:
    """Run one virtual try-on job end to end and return the result image.

    Steps: downscale and upload both inputs to Supabase storage, insert a
    ``processing_jobs`` row, trigger the Edge Function, wait for the row to
    complete, then load the result image.

    Args:
        base_img: Person photo from the UI (``None`` if not provided).
        garment_img: Product image from the UI (``None`` if not provided).
        workflow_choice: One of ``WORKFLOW_CHOICES`` (case-insensitive);
            empty or unknown values fall back to ``"eyewear"``.

    Returns:
        The result as an RGBA PIL image.

    Raises:
        gr.Error: on missing inputs, backend/network failures, or timeout.
    """
    if base_img is None or garment_img is None:
        raise gr.Error("Please provide both images.")

    # 1. Persist both images to Supabase storage under a per-job folder.
    job_id = str(uuid.uuid4())
    folder = f"user_uploads/gradio/{job_id}"
    base_path = f"{folder}/{uuid.uuid4().hex}.png"
    garment_path = f"{folder}/{uuid.uuid4().hex}.png"
    base_img = downscale_image(base_img)
    garment_img = downscale_image(garment_img)
    base_url = upload_image_to_supabase(base_img, base_path)
    garment_url = upload_image_to_supabase(garment_img, garment_path)

    # 2. Insert the job row (module client uses the secret key, bypassing RLS).
    insert_payload = {
        "id": job_id,
        "status": "queued",
        "base_image_path": base_url,
        "garment_image_path": garment_url,
        # NOTE(review): the row records base_url as the mask while the Edge
        # Function below receives garment_url — confirm which is intended.
        "mask_image_path": base_url,
        "access_token": str(uuid.uuid4()),
        # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated).
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    supabase.table("processing_jobs").insert(insert_payload).execute()

    # 3. Trigger the Edge Function with a validated workflow name.
    workflow_choice = (workflow_choice or "eyewear").lower()
    if workflow_choice not in WORKFLOW_CHOICES:
        workflow_choice = "eyewear"
    _trigger_edge_function(
        {
            "baseImageUrl": base_url,
            "garmentImageUrl": garment_url,
            # HACK: garment doubles as a placeholder mask until a proper mask
            # is provided.
            "maskImageUrl": garment_url,
            "jobId": job_id,
            "workflowType": workflow_choice,
        }
    )

    # 4. Wait for completion via Realtime (or polling fallback).
    result = wait_for_job_completion(job_id)
    if not result:
        raise gr.Error("Timed out waiting for job to finish.")
    return _load_result_image(result)
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
# Markdown shown under the page title.
# NOTE(review): the text says "Full-Body Garments" while the selector value is
# "dress", and the emoji before **Generate** appears mis-encoded — confirm.
description = "Upload a person photo (Base) and a product image. Select between Eyewear, Footwear, or Full-Body Garments to switch between the three available models. Click ๐ **Generate** to try on a product."  # noqa: E501

# Three-column layout: inputs | result | model selector + Generate button.
with gr.Blocks(title="YOURMIRROR.IO - SM4LL-VTON Demo") as demo:
    # Header
    gr.Markdown("# SM4LL-VTON PRE-RELEASE DEMO | YOURMIRROR.IO | Virtual Try-On")
    gr.Markdown(description)
    # Display size in pixels for the two input image widgets.
    IMG_SIZE = 256
    with gr.Row():
        # Left column — inputs stacked vertically. type="pil" makes Gradio
        # pass PIL images to generate().
        with gr.Column(scale=1):
            base_in = gr.Image(
                label="Base Image",
                type="pil",
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
            garment_in = gr.Image(
                label="Product Image",
                type="pil",
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
        # Centre column — result image (larger)
        with gr.Column(scale=2):
            result_out = gr.Image(
                label="Result",
                height=512,
                width=512,
            )
        # Right column — workflow selector (defaults to "eyewear") and
        # Generate button
        with gr.Column(scale=1, elem_classes="control-column"):
            workflow_selector = gr.Radio(
                choices=WORKFLOW_CHOICES,
                value="eyewear",
                label="Model",
            )
            generate_btn = gr.Button("Generate", variant="primary", size="lg")
    # Wire up interaction: clicking Generate calls generate(base, product,
    # workflow) and shows the returned image in the Result widget.
    generate_btn.click(
        generate,
        inputs=[base_in, garment_in, workflow_selector],
        outputs=result_out,
    )

# Start the server only when this file is executed directly (e.g.
# `python app.py`); importing the module just builds `demo`. HF Spaces runs
# `python app.py` automatically when `app.py` sits at the repo root, but this
# file is kept in a sub-folder, so the explicit launch guard is retained.
if __name__ == "__main__":
    demo.launch()