# app.py
"""
VQA: Memory + RL Controller (Gradio app)
- Drag-and-drop an image, ask a question, and see the model's answer + chosen strategy.
- Tries to import `answer_with_controller` from controller.py. Falls back to a stub if missing.
- Works on Hugging Face Spaces, Render, Docker, or local run.
"""
import os
import sys
import time
import traceback
import subprocess
from typing import Tuple, Optional
# Ensure gradio is available when running locally; Spaces installs from requirements.txt
try:
import gradio as gr
except ImportError: # pragma: no cover
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "gradio"])
import gradio as gr
from PIL import Image
# -----------------------------
# Attempt to import real handler
# -----------------------------
def _make_fallback():
def _fallback_answer_with_controller(
image: Image.Image,
question: str,
source: str = "auto",
distilled_model: str = "auto",
) -> Tuple[str, str, int]:
# Replace with real inference to remove this placeholder.
return "Placeholder answer (wire your models in controller.py).", "baseline", 0
return _fallback_answer_with_controller
try:
# Expect controller.py to define: answer_with_controller(image, question, source, distilled_model)
from controller import answer_with_controller # type: ignore
except Exception as e:
print(f"[WARN] Using fallback controller because import failed: {e}", flush=True)
answer_with_controller = _make_fallback()
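# For reference, a minimal controller.py satisfying the contract above might
# look like the sketch below (hypothetical names and return values; the real
# module is expected to wire up the actual VQA models and the RL gate):
#
#   def answer_with_controller(image, question, source="auto", distilled_model="auto"):
#       """Return (answer_text, strategy_name, action_id)."""
#       answer = run_vqa_model(image, question)   # hypothetical helper
#       return answer, "greedy", 0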
# -----------------------------
# UI Constants
# -----------------------------
TITLE = "VQA β Memory + RL Controller"
DESCRIPTION = (
"Upload an image, enter a question, and the controller will choose the best decoding strategy."
)
CONTROLLER_SOURCES = ["auto", "distilled", "ppo", "baseline"]
DISTILLED_CHOICES = ["auto", "logreg", "mlp32"]
# -----------------------------
# Inference wrapper with guards
# -----------------------------
def vqa_demo_fn(
image: Optional[Image.Image],
question: str,
source: str,
distilled_model: str,
) -> Tuple[str, str, float]:
"""Safely run inference and return (answer, strategy_label, latency_ms)."""
# Input validation
if image is None:
return "Please upload an image.", "", 0.0
question = (question or "").strip()
if not question:
return "Please enter a question.", "", 0.0
# Convert & measure latency
t0 = time.perf_counter()
try:
        # Normalise to RGB: palette ("P") and greyscale ("L") images can trip the models
image_rgb = image.convert("RGB")
pred, strategy_name, action_id = answer_with_controller(
image_rgb,
question,
source=source,
distilled_model=distilled_model,
)
latency_ms = (time.perf_counter() - t0) * 1000.0
# Friendly formatting
        strategy_out = f"{action_id} → {strategy_name}"
return str(pred), strategy_out, round(latency_ms, 1)
    except Exception as err:
        # Never crash the app: show a concise error to the user, log full details server-side
        latency_ms = (time.perf_counter() - t0) * 1000.0
        print("[ERROR] Inference failed:\n" + traceback.format_exc(), flush=True)
        return f"Error: {err}", "error", round(latency_ms, 1)
# -----------------------------
# Build Gradio Interface
# -----------------------------
with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
gr.Markdown(f"### {TITLE}\n{DESCRIPTION}")
with gr.Row():
with gr.Column():
img_in = gr.Image(
type="pil",
label="Image",
height=320,
sources=["upload", "drag-and-drop", "clipboard", "webcam"],
image_mode="RGB",
)
q_in = gr.Textbox(
label="Question",
placeholder="e.g., What colour is the bus?",
lines=2,
max_lines=4,
)
source_in = gr.Radio(
CONTROLLER_SOURCES,
value="auto",
label="Controller Source",
)
dist_in = gr.Radio(
DISTILLED_CHOICES,
value="auto",
label="Distilled Gate (if used)",
)
run_btn = gr.Button("Predict", variant="primary")
with gr.Column():
ans_out = gr.Textbox(label="Answer", interactive=False, lines=3, max_lines=6)
strat_out = gr.Textbox(label="Chosen Strategy", interactive=False)
lat_out = gr.Number(label="Latency (ms)", precision=1, interactive=False)
run_btn.click(
vqa_demo_fn,
inputs=[img_in, q_in, source_in, dist_in],
outputs=[ans_out, strat_out, lat_out],
api_name="predict",
)
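# With api_name="predict", the endpoint is also callable programmatically. A
# minimal sketch, assuming gradio_client >= 1.0 and a hypothetical deployed
# Space name:
#   from gradio_client import Client, handle_file
#   client = Client("user/vqa-controller-demo")
#   answer, strategy, ms = client.predict(
#       handle_file("bus.png"), "What colour is the bus?", "auto", "auto",
#       api_name="/predict",
#   )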
# -----------------------------
# Launch
# -----------------------------
if __name__ == "__main__":
# Respect $PORT for Spaces/Render/Docker; default to 7860 locally
port = int(os.getenv("PORT", "7860"))
# Queue improves robustness under load
    demo.queue(default_concurrency_limit=2)  # per-event concurrency (Gradio 4.x API)
demo.launch(
server_name="0.0.0.0",
server_port=port,
share=False, # set True only for local quick sharing
show_error=True,
)