Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import json | |
| import random | |
| from functools import lru_cache | |
| from typing import List, Tuple, Optional, Any | |
| import gradio as gr | |
| from huggingface_hub import InferenceClient, hf_hub_download | |
# =============================================================================
# Configuration
# =============================================================================

# Repo IDs of the LoRAs in the "Kontext Dev LoRAs" collection, maintained by
# hand: when the collection gains a new LoRA, add its repo ID to the matching
# group below. Order is significant — the very first ID is the dropdown's
# initial selection.

# LoRAs published by fal, the original author.
_FAL_LORAS = (
    "fal/Watercolor-Art-Kontext-Dev-LoRA",
    "fal/Pop-Art-Kontext-Dev-LoRA",
    "fal/Pencil-Drawing-Kontext-Dev-LoRA",
    "fal/Mosaic-Art-Kontext-Dev-LoRA",
    "fal/Minimalist-Art-Kontext-Dev-LoRA",
    "fal/Impressionist-Art-Kontext-Dev-LoRA",
    "fal/Gouache-Art-Kontext-Dev-LoRA",
    "fal/Expressive-Art-Kontext-Dev-LoRA",
    "fal/Cubist-Art-Kontext-Dev-LoRA",
    "fal/Collage-Art-Kontext-Dev-LoRA",
    "fal/Charcoal-Art-Kontext-Dev-LoRA",
    "fal/Acrylic-Art-Kontext-Dev-LoRA",
    "fal/Abstract-Art-Kontext-Dev-LoRA",
    "fal/Plushie-Kontext-Dev-LoRA",
    "fal/Youtube-Thumbnails-Kontext-Dev-LoRA",
    "fal/Broccoli-Hair-Kontext-Dev-LoRA",
    "fal/Wojak-Kontext-Dev-LoRA",
    "fal/3D-Game-Assets-Kontext-Dev-LoRA",
    "fal/Realism-Detailer-Kontext-Dev-LoRA",
)

# LoRAs contributed by the community.
_COMMUNITY_LORAS = (
    "gokaygokay/Pencil-Drawing-Kontext-Dev-LoRA",
    "gokaygokay/Oil-Paint-Kontext-Dev-LoRA",
    "gokaygokay/Watercolor-Kontext-Dev-LoRA",
    "gokaygokay/Pastel-Flux-Kontext-Dev-LoRA",
    "gokaygokay/Low-Poly-Kontext-Dev-LoRA",
    "gokaygokay/Bronze-Sculpture-Kontext-Dev-LoRA",
    "gokaygokay/Marble-Sculpture-Kontext-Dev-LoRA",
    "gokaygokay/Light-Fix-Kontext-Dev-LoRA",
    "gokaygokay/Fuse-it-Kontext-Dev-LoRA",
    "ilkerzgi/Overlay-Kontext-Dev-LoRA",
)

# Flat, ordered list consumed by the dropdown and the example gallery.
LORA_MODELS: List[str] = [*_FAL_LORAS, *_COMMUNITY_LORAS]
| # Optional metadata cache file. Generated by `generate_lora_metadata.py`. | |
| METADATA_FILE = "lora_metadata.json" | |
| def _load_metadata() -> dict: | |
| """Load cached preview/trigger data if the JSON file exists.""" | |
| if os.path.exists(METADATA_FILE): | |
| try: | |
| with open(METADATA_FILE, "r", encoding="utf-8") as fp: | |
| return json.load(fp) | |
| except Exception: | |
| pass | |
| return {} | |
# Environment variable holding the token used for anonymous (not logged in)
# requests.
FREE_TOKEN_ENV = "HF_TOKEN"
# Maximum anonymous requests, counted in per-session Gradio state.
FREE_REQUESTS = 10

# -----------------------------------------------------------------------------
# Utility helpers
# -----------------------------------------------------------------------------


@lru_cache(maxsize=32)
def get_client(token: str) -> InferenceClient:
    """Return an InferenceClient that routes requests through the fal-ai provider.

    Clients are memoised per token with a bounded LRU cache. Repeated
    generations with the same token therefore reuse one client instead of
    constructing a new one on every call.

    Args:
        token: Hugging Face API token used to authenticate requests.
    """
    return InferenceClient(provider="fal-ai", api_key=token)
| IMG_PATTERN = re.compile(r"!\[.*?\]\((.*?)\)") | |
| TRIGGER_PATTERN = re.compile(r"[Tt]rigger[^:]*:\s*([^\n]+)") | |
| def fetch_preview_and_trigger(model_id: str) -> Tuple[Optional[str], Optional[str]]: | |
| """Try to fetch a preview image URL and trigger phrase from the model card. | |
| If unsuccessful, returns (None, None). | |
| """ | |
| try: | |
| # Download README. | |
| readme_path = hf_hub_download(repo_id=model_id, filename="README.md") | |
| except Exception: | |
| return None, None | |
| image_url: Optional[str] = None | |
| trigger_phrase: Optional[str] = None | |
| try: | |
| with open(readme_path, "r", encoding="utf-8") as fp: | |
| text = fp.read() | |
| # First image in markdown → preview | |
| if (m := IMG_PATTERN.search(text)) is not None: | |
| img_path = m.group(1) | |
| if img_path.startswith("http"): | |
| image_url = img_path | |
| else: | |
| image_url = f"https://huggingface.co/{model_id}/resolve/main/{img_path.lstrip('./')}" | |
| # Try to parse trigger phrase | |
| if (m := TRIGGER_PATTERN.search(text)) is not None: | |
| trigger_phrase = m.group(1).strip() | |
| except Exception: | |
| pass | |
| return image_url, trigger_phrase | |
# -----------------------------------------------------------------------------
# Core inference function
# -----------------------------------------------------------------------------
def run_lora(
    input_image,  # PIL.Image (the Gradio input uses type="pil") or raw bytes
    prompt: str,
    model_id: str,
    guidance_scale: float,
    token: str | None,
    req_count: int,
) -> Tuple[Any, int, str]:
    """Run image-to-image generation with the selected Kontext LoRA.

    Args:
        input_image: Source image; either raw encoded bytes or an object with a
            PIL-style ``save(fp, format=...)`` method.
        prompt: Text instruction forwarded to the model.
        model_id: Hub repo ID of the LoRA to run.
        guidance_scale: Guidance scale forwarded to the inference provider.
        token: The user's HF token; falsy means "anonymous, use free quota".
        req_count: Free-quota requests already used in this session.

    Returns:
        ``(output_image, new_req_count, status_markdown)``. The counter only
        advances for anonymous requests.

    Raises:
        gr.Error: If the input image is missing or unsupported, free usage is
            not configured, the free quota is exhausted, or the inference call
            fails.
    """
    import io

    if input_image is None:
        raise gr.Error("Please provide an input image.")

    # Pick the token to authenticate with.
    if token:
        api_token = token
    else:
        free_token = os.getenv(FREE_TOKEN_ENV)
        # An empty variable is as unusable as a missing one.
        if not free_token:
            raise gr.Error("Service not configured for free usage. Please login.")
        # NOTE(review): req_count comes from per-session Gradio state, so the
        # quota is enforced per session rather than per person.
        if req_count >= FREE_REQUESTS:
            raise gr.Error("Free quota exceeded – please login with your own HF account to continue.")
        api_token = free_token

    # Normalise the image to encoded bytes. Probe for `save`, not `tobytes`:
    # numpy arrays have `tobytes` but cannot be written out as PNG.
    if isinstance(input_image, (bytes, bytearray)):
        img_bytes = bytes(input_image)
    elif hasattr(input_image, "save"):
        buf = io.BytesIO()
        input_image.save(buf, format="PNG")
        img_bytes = buf.getvalue()
    else:
        raise gr.Error("Unsupported image format.")

    client = get_client(api_token)
    try:
        output = client.image_to_image(
            img_bytes,
            prompt=prompt,
            model=model_id,
            guidance_scale=guidance_scale,
        )
    except Exception as err:
        # Surface provider/network failures in the UI with their message.
        raise gr.Error(f"Generation failed: {err}") from err

    if token:
        # Logged-in users are not metered.
        return output, req_count, "Logged in ✅ Unlimited"
    new_count = req_count + 1
    return output, new_count, f"Free requests remaining: {max(0, FREE_REQUESTS - new_count)}"
# -----------------------------------------------------------------------------
# UI assembly
# -----------------------------------------------------------------------------
def build_interface():
    """Build the Gradio Blocks app for the Kontext-Dev LoRA playground.

    The layout has four parts:

    * authentication: an HF login button when this Gradio build supports
      OAuth, otherwise a manual token textbox;
    * a model picker with a sample image and the suggested trigger phrase;
    * the generation inputs, wired to ``run_lora``;
    * a gallery of cached example outputs. Clicking one selects its model.

    Returns:
        The (not yet launched) ``gr.Blocks`` instance.
    """
    # Read the metadata file once; the callbacks below close over this dict.
    metadata_cache = _load_metadata()
    if not isinstance(metadata_cache, dict):
        # The callbacks below look entries up with `.get()`.
        metadata_cache = {}

    not_logged_in_msg = f"Not logged in – using free quota (max {FREE_REQUESTS})"
    logged_in_msg = "Logged in ✅ Unlimited"

    # Theme & CSS
    theme = gr.themes.Soft(primary_hue="violet", secondary_hue="indigo")
    custom_css = """
    .gradio-container {max-width: 980px; margin: auto;}
    .gallery-item {border-radius: 8px; overflow: hidden;}
    """

    with gr.Blocks(title="Kontext-Dev LoRA Playground", theme=theme, css=custom_css) as demo:
        # Per-session state: the user's token ("" = anonymous) and the number
        # of free requests used so far.
        token_state = gr.State(value="")
        request_count_state = gr.State(value=0)

        # --- Authentication UI -------------------------------------------
        if hasattr(gr, "LoginButton") and hasattr(gr, "OAuthToken"):
            gr.LoginButton()
            token_status = gr.Markdown(value=not_logged_in_msg)

            # gr.LoginButton exposes no `.login` event. Signing in reloads the
            # page, and Gradio injects the OAuth token into any handler
            # parameter annotated as gr.OAuthToken. Reading it on page load is
            # therefore enough.
            def _handle_login(oauth_token: Optional[gr.OAuthToken]):
                """Copy the OAuth access token (if any) into session state."""
                token = oauth_token.token if oauth_token is not None else ""
                return token, (logged_in_msg if token else not_logged_in_msg)

            demo.load(_handle_login, inputs=None, outputs=[token_state, token_status])
        else:
            # Fallback for Gradio builds without OAuth support: manual token entry.
            with gr.Accordion("🔑 Paste your HF token (optional)", open=False):
                token_input = gr.Textbox(label="HF Token", type="password", placeholder="Paste your token here…")
                save_token_btn = gr.Button("Save token")
            token_status = gr.Markdown(value=not_logged_in_msg)

            def _save_token(tok):
                """Store the pasted token and report the resulting login state."""
                tok = (tok or "").strip()
                return tok, (logged_in_msg if tok else not_logged_in_msg)

            save_token_btn.click(_save_token, inputs=token_input, outputs=[token_state, token_status])

        gr.Markdown(
            """
            # Kontext-Dev LoRA Playground
            Select one of the available LoRAs from the dropdown, upload an image, tweak the prompt, and generate!
            """
        )

        with gr.Row():
            # LEFT column – model selection + preview
            with gr.Column(scale=1):
                model_dropdown = gr.Dropdown(
                    choices=LORA_MODELS,
                    value=LORA_MODELS[0],
                    label="Select LoRA model",
                )
                preview_image = gr.Image(label="Sample image", interactive=False, height=256)
                trigger_text = gr.Textbox(
                    label="Trigger phrase (suggested)",
                    interactive=False,
                )
            # RIGHT column – user inputs
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Input image",
                    type="pil",
                )
                prompt_box = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe your transformation…",
                )
                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=2.5,
                    step=0.1,
                    label="Guidance scale",
                )
                generate_btn = gr.Button("🚀 Generate")
                output_image = gr.Image(label="Output", interactive=False)
                quota_display = gr.Markdown(value=f"Free requests remaining: {FREE_REQUESTS}")

        # Showcase Gallery --------------------------------------------------
        gr.Markdown("## ✨ Example outputs from selected LoRAs")
        example_gallery = gr.Gallery(
            label="Examples",
            columns=[4],
            height="auto",
            elem_id="example_gallery",
        )
        # Mirrors the gallery's [image_url, model_id] items so a click index
        # can be mapped back to its model.
        gallery_data_state = gr.State([])

        # ------------------------------------------------------------------
        # Callbacks
        # ------------------------------------------------------------------
        def _update_preview(model_id, _meta=metadata_cache):
            """Show the sample image and trigger phrase for *model_id*.

            Uses the cached metadata when it has an entry for the model, and
            otherwise parses the model card fetched from the Hub.
            """
            info = _meta.get(model_id)
            if isinstance(info, dict):
                img_url = info.get("image_url")
                trig = info.get("trigger_phrase")
            else:
                img_url, trig = fetch_preview_and_trigger(model_id)
            return (
                gr.update(value=img_url or None),
                gr.update(value=trig or "(no trigger phrase provided)"),
                # Prefill the prompt only with a real trigger phrase so the
                # placeholder text is never sent to the model.
                gr.update(value=trig or ""),
            )

        model_dropdown.change(_update_preview, inputs=model_dropdown, outputs=[preview_image, trigger_text, prompt_box])

        generate_btn.click(
            fn=run_lora,
            inputs=[input_image, prompt_box, model_dropdown, guidance, token_state, request_count_state],
            outputs=[output_image, request_count_state, quota_display],
        )

        def _load_gallery(_meta=metadata_cache):
            """Pick up to 12 random cached sample images for the showcase gallery.

            The same selection is returned twice: once for display and once for
            ``gallery_data_state``, so click indices line up with model IDs.
            """
            samples = []
            for model_id in LORA_MODELS:
                info = _meta.get(model_id)
                if isinstance(info, dict) and info.get("image_url"):
                    samples.append([info["image_url"], model_id])
            random.shuffle(samples)
            return samples[:12], samples[:12]

        # Initialise preview and gallery on launch
        demo.load(_update_preview, inputs=model_dropdown, outputs=[preview_image, trigger_text, prompt_box])
        demo.load(fn=_load_gallery, inputs=None, outputs=[example_gallery, gallery_data_state])

        def _on_gallery_select(evt: gr.SelectData, data):
            """Select the LoRA whose example image was clicked in the gallery."""
            idx = evt.index
            data = data or []
            # gr.update() (not the removed gr.Dropdown.update) leaves the
            # dropdown unchanged.
            if not isinstance(idx, int) or not 0 <= idx < len(data):
                return gr.update()
            return gr.update(value=data[idx][1])

        example_gallery.select(_on_gallery_select, inputs=gallery_data_state, outputs=model_dropdown)

    return demo
def main():
    """Assemble the playground UI and hand control to the Gradio server."""
    build_interface().launch()


# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    main()