"""Gradio demo Space for Trainingless โ€“ see plan.md for details."""
from __future__ import annotations
import base64
import io
import os
import time
import uuid
from datetime import datetime, timezone
from typing import Optional
import requests
from dotenv import load_dotenv
from PIL import Image
import gradio as gr
from supabase import create_client, Client
# -----------------------------------------------------------------------------
# Environment & Supabase setup
# -----------------------------------------------------------------------------
# Load .env file *once* when running locally. The HF Spaces runtime injects the
# same names via its Secrets mechanism, so calling load_dotenv() is harmless.
load_dotenv()
SUPABASE_URL: str = os.getenv("SUPABASE_URL", "")
# Use a *secret* (server-only) key so the backend bypasses RLS.
SUPABASE_SECRET_KEY: str = os.getenv("SUPABASE_SECRET_KEY", "")
# (Optional) You can override which Edge Function gets called.
SUPABASE_FUNCTION_URL: str = os.getenv(
    "SUPABASE_FUNCTION_URL", f"{SUPABASE_URL}/functions/v1/process-image"
)
# Storage bucket for uploads. Must be *public*.
UPLOAD_BUCKET = os.getenv("SUPABASE_UPLOAD_BUCKET", "images")
REQUEST_TIMEOUT = int(os.getenv("SUPABASE_FN_TIMEOUT", "240")) # seconds
# Available model workflows recognised by the edge function
WORKFLOW_CHOICES = [
    "eyewear",
    "footwear",
    "dress",
]
if not SUPABASE_URL or not SUPABASE_SECRET_KEY:
    raise RuntimeError(
        "SUPABASE_URL and SUPABASE_SECRET_KEY must be set in the environment."
    )
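# Expected local .env entries (placeholder values, not real credentials):
#   SUPABASE_URL=https://<project-ref>.supabase.co
#   SUPABASE_SECRET_KEY=<service-role-key>
#   SUPABASE_UPLOAD_BUCKET=images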
# -----------------------------------------------------------------------------
# Supabase client – server-side: authenticate with the secret key (bypasses RLS)
# -----------------------------------------------------------------------------
supabase: Client = create_client(SUPABASE_URL, SUPABASE_SECRET_KEY)
# Ensure the uploads bucket exists (idempotent). Creating it requires the
# service-role key, so on restricted keys this is a best-effort no-op.
try:
    buckets = supabase.storage.list_buckets()  # type: ignore[attr-defined]
    # Depending on the client version, entries are dicts or bucket objects.
    bucket_names = (
        {b.name if hasattr(b, "name") else b["name"] for b in buckets}
        if isinstance(buckets, list)
        else set()
    )
    if UPLOAD_BUCKET not in bucket_names:
        # Attempt to create the bucket (fails with an anon key – the bucket
        # must then be created manually in the dashboard).
        try:
            supabase.storage.create_bucket(
                UPLOAD_BUCKET,
                options={"public": True},  # storage3 expects an options dict
            )
            print(f"[startup] Created bucket '{UPLOAD_BUCKET}'.")
        except Exception as create_exc:  # noqa: BLE001
            print(f"[startup] Could not create bucket '{UPLOAD_BUCKET}': {create_exc!r}")
except Exception as exc:  # noqa: BLE001
    # Non-fatal. The bucket probably already exists or we don't have perms.
    print(f"[startup] Bucket check/create raised {exc!r}. Continuing…")
# -----------------------------------------------------------------------------
# Helper functions
# -----------------------------------------------------------------------------
def pil_to_bytes(img: Image.Image) -> bytes:
"""Convert PIL Image to PNG bytes."""
with io.BytesIO() as buffer:
img.save(buffer, format="PNG")
return buffer.getvalue()
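# Illustrative sanity check (not executed here): the returned bytes start with
# the fixed 8-byte PNG signature.
#   pil_to_bytes(Image.new("RGB", (1, 1)))[:8] == b"\x89PNG\r\n\x1a\n"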
def upload_image_to_supabase(img: Image.Image, path: str) -> str:
"""Upload image under `UPLOAD_BUCKET/path` and return **public URL**."""
data = pil_to_bytes(img)
# Overwrite if exists
supabase.storage.from_(UPLOAD_BUCKET).upload(
path,
data,
{"content-type": "image/png", "upsert": "true"}, # upsert must be string
) # type: ignore[attr-defined]
public_url = (
f"{SUPABASE_URL}/storage/v1/object/public/{UPLOAD_BUCKET}/{path}"
)
return public_url
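# Note: the public URL is assembled by convention here; storage3's
# from_(UPLOAD_BUCKET).get_public_url(path) should produce the same
# `/storage/v1/object/public/<bucket>/<path>` form for public buckets.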
def wait_for_job_completion(job_id: str, timeout_s: int = 600) -> Optional[str]:
"""Subscribe to the single row via Realtime. Fallback to polling every 5 s."""
# First try realtime subscription (non-blocking). If it errors, fall back.
completed_image: Optional[str] = None
did_subscribe = False
try:
# Docs: https://supabase.com/docs/reference/python/creating-channels
channel = (
supabase.channel("job_channel")
.on(
"postgres_changes",
{
"event": "UPDATE",
"schema": "public",
"table": "processing_jobs",
"filter": f"id=eq.{job_id}",
},
lambda payload: _realtime_callback(payload, job_id),
)
.subscribe()
)
did_subscribe = True
except Exception as exc: # noqa: BLE001
print(f"[wait] Realtime subscription failed โ€“ will poll: {exc!r}")
start = time.time()
while time.time() - start < timeout_s:
if _RESULT_CACHE.get(job_id):
completed_image = _RESULT_CACHE.pop(job_id)
break
if not did_subscribe or (time.time() - start) % 5 == 0:
# Poll once every ~5 s
data = (
supabase.table("processing_jobs")
.select("status,result_image_url")
.eq("id", job_id)
.single()
.execute()
)
if data.data and data.data["status"] == "completed":
completed_image = data.data.get("result_image_url")
break
time.sleep(1)
try:
if did_subscribe:
supabase.remove_channel(channel)
except Exception: # noqa: PIE786, BLE001
pass
return completed_image
_RESULT_CACHE: dict[str, str] = {}
def _realtime_callback(payload: dict, job_id: str) -> None:
    new = payload.get("new", {})  # type: ignore[index]
    if new.get("status") == "completed":
        _RESULT_CACHE[job_id] = new.get("result_image_url")
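# The callback is invoked by the realtime client, not the Gradio request
# thread; it only writes a single key into _RESULT_CACHE, which
# wait_for_job_completion picks up on its next loop iteration.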
MAX_PIXELS = 1_500_000 # 1.5 megapixels ceiling for each uploaded image
def downscale_image(img: Image.Image, max_pixels: int = MAX_PIXELS) -> Image.Image:
"""Downscale *img* proportionally so that widthร—height โ‰ค *max_pixels*.
If the image is already small enough, it is returned unchanged.
"""
w, h = img.size
if w * h <= max_pixels:
return img
scale = (max_pixels / (w * h)) ** 0.5 # uniform scaling factor
new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
return img.resize(new_size, Image.LANCZOS)
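# Worked example: a 4000x3000 (12 MP) photo gives scale = (1.5e6 / 12e6) ** 0.5
# ≈ 0.354, so it is resized to roughly 1414x1060 ≈ 1.5 MP before upload.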
def _public_storage_url(path: str) -> str:
"""Return a public (https) URL given an object *path* inside any bucket.
If *path* already looks like a full URL, it is returned unchanged.
"""
if path.startswith("http://") or path.startswith("https://"):
return path
# Ensure no leading slash.
return f"{SUPABASE_URL}/storage/v1/object/public/{path.lstrip('/')}"
# -----------------------------------------------------------------------------
# Main generate function
# -----------------------------------------------------------------------------
def generate(base_img: Image.Image, garment_img: Image.Image, workflow_choice: str) -> Image.Image:
    if base_img is None or garment_img is None:
        raise gr.Error("Please provide both images.")

    # 1. Persist both images to Supabase storage
    job_id = str(uuid.uuid4())
    folder = f"user_uploads/gradio/{job_id}"
    base_filename = f"{uuid.uuid4().hex}.png"
    garment_filename = f"{uuid.uuid4().hex}.png"
    base_path = f"{folder}/{base_filename}"
    garment_path = f"{folder}/{garment_filename}"
    base_img = downscale_image(base_img)
    garment_img = downscale_image(garment_img)
    base_url = upload_image_to_supabase(base_img, base_path)
    garment_url = upload_image_to_supabase(garment_img, garment_path)

    # 2. Insert a new row into processing_jobs (the secret key bypasses RLS)
    token_for_row = str(uuid.uuid4())
    insert_payload = {
        "id": job_id,
        "status": "queued",
        "base_image_path": base_url,
        "garment_image_path": garment_url,
        "mask_image_path": base_url,
        "access_token": token_for_row,
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    supabase.table("processing_jobs").insert(insert_payload).execute()

    # 3. Trigger the edge function
    workflow_choice = (workflow_choice or "eyewear").lower()
    if workflow_choice not in WORKFLOW_CHOICES:
        workflow_choice = "eyewear"
    fn_payload = {
        "baseImageUrl": base_url,
        "garmentImageUrl": garment_url,
        # 👉 hack: use the garment as a placeholder mask until a proper mask is provided
        "maskImageUrl": garment_url,
        "jobId": job_id,
        "workflowType": workflow_choice,
    }
    headers = {
        "Content-Type": "application/json",
        "apikey": SUPABASE_SECRET_KEY,
        "Authorization": f"Bearer {SUPABASE_SECRET_KEY}",
    }
    resp = requests.post(
        SUPABASE_FUNCTION_URL,
        json=fn_payload,
        headers=headers,
        timeout=REQUEST_TIMEOUT,
    )
    if not resp.ok:
        raise gr.Error(f"Backend error: {resp.text}")

    # 4. Wait for completion via realtime (or the polling fallback)
    result = wait_for_job_completion(job_id)
    if not result:
        raise gr.Error("Timed out waiting for job to finish.")

    # Result may be a base64 data URI or an http(s) URL; normalise.
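    # e.g. "data:image/png;base64,iVBORw0KGgo..." vs. a bucket-relative storage path.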
    if result.startswith("data:image"):
        _header, b64 = result.split(",", 1)
        img_bytes = base64.b64decode(b64)
        result_img = Image.open(io.BytesIO(img_bytes)).convert("RGBA")
    else:
        result_url = _public_storage_url(result)
        resp_img = requests.get(result_url, timeout=30)
        resp_img.raise_for_status()
        result_img = Image.open(io.BytesIO(resp_img.content)).convert("RGBA")
    return result_img
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
description = "Upload a person photo (Base) and a product image. Choose Eyewear, Footwear, or Dress (full-body garments) to switch between the three available models. Click 👉 **Generate** to try on a product."  # noqa: E501
with gr.Blocks(title="YOURMIRROR.IO - SM4LL-VTON Demo") as demo:
    # Header
    gr.Markdown("# SM4LL-VTON PRE-RELEASE DEMO | YOURMIRROR.IO | Virtual Try-On")
    gr.Markdown(description)

    IMG_SIZE = 256
    with gr.Row():
        # Left column → inputs stacked vertically
        with gr.Column(scale=1):
            base_in = gr.Image(
                label="Base Image",
                type="pil",
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
            garment_in = gr.Image(
                label="Product Image",
                type="pil",
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
        # Centre column → result image (larger)
        with gr.Column(scale=2):
            result_out = gr.Image(
                label="Result",
                height=512,
                width=512,
            )
        # Right column → workflow selector and Generate button
        with gr.Column(scale=1, elem_classes="control-column"):
            workflow_selector = gr.Radio(
                choices=WORKFLOW_CHOICES,
                value="eyewear",
                label="Model",
            )
            generate_btn = gr.Button("Generate", variant="primary", size="lg")

    # Wire up the interaction
    generate_btn.click(
        generate,
        inputs=[base_in, garment_in, workflow_selector],
        outputs=result_out,
    )
# Run app if executed directly (e.g. `python app.py`). HF Spaces launches via
# `python app.py` automatically if it finds `app.py` at repo root, but our file
# lives in a sub-folder, so we keep the guard.
if __name__ == "__main__":
    demo.launch()