# app.py
"""
VQA — Memory + RL Controller (Gradio app)
- Drag and drop an image, ask a question, and see the model's answer plus the chosen strategy.
- Tries to import `answer_with_controller` from controller.py; falls back to a stub if missing.
- Runs on Hugging Face Spaces, Render, in Docker, or locally.
"""

import os
import sys
import time
import traceback
import subprocess
from typing import Tuple, Optional

# Ensure gradio is available when running locally; Spaces installs from requirements.txt
try:
    import gradio as gr
except ImportError:  # pragma: no cover
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "gradio"])
    import gradio as gr

from PIL import Image

# -----------------------------
# Attempt to import real handler
# -----------------------------
def _make_fallback():
    def _fallback_answer_with_controller(
        image: Image.Image,
        question: str,
        source: str = "auto",
        distilled_model: str = "auto",
    ) -> Tuple[str, str, int]:
        # Replace with real inference to remove this placeholder.
        return "Placeholder answer (wire your models in controller.py).", "baseline", 0
    return _fallback_answer_with_controller

try:
    # Expect controller.py to define: answer_with_controller(image, question, source, distilled_model)
    from controller import answer_with_controller  # type: ignore
except Exception as e:
    print(f"[WARN] Using fallback controller because import failed: {e}", flush=True)
    answer_with_controller = _make_fallback()
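
# For reference, a minimal controller.py satisfying this contract might look like
# the sketch below (hypothetical: `run_vqa_model` stands in for your real
# inference call; only the (answer, strategy_name, action_id) return shape matters):
#
#   def answer_with_controller(image, question, source="auto", distilled_model="auto"):
#       answer = run_vqa_model(image, question)
#       return answer, "baseline", 0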

# -----------------------------
# UI Constants
# -----------------------------
TITLE = "VQA β€” Memory + RL Controller"
DESCRIPTION = (
    "Upload an image, enter a question, and the controller will choose the best decoding strategy."
)

CONTROLLER_SOURCES = ["auto", "distilled", "ppo", "baseline"]
DISTILLED_CHOICES = ["auto", "logreg", "mlp32"]

# -----------------------------
# Inference wrapper with guards
# -----------------------------
def vqa_demo_fn(
    image: Optional[Image.Image],
    question: str,
    source: str,
    distilled_model: str,
) -> Tuple[str, str, float]:
    """Safely run inference and return (answer, strategy_label, latency_ms)."""
    # Input validation
    if image is None:
        return "Please upload an image.", "", 0.0
    question = (question or "").strip()
    if not question:
        return "Please enter a question.", "", 0.0

    # Convert & measure latency
    t0 = time.perf_counter()
    try:
        # Convert to RGB to avoid issues with PNG/L mode
        image_rgb = image.convert("RGB")

        pred, strategy_name, action_id = answer_with_controller(
            image_rgb,
            question,
            source=source,
            distilled_model=distilled_model,
        )

        latency_ms = (time.perf_counter() - t0) * 1000.0
        # Friendly formatting
        strategy_out = f"{action_id} → {strategy_name}"
        return str(pred), strategy_out, round(latency_ms, 1)

    except Exception as err:
        # Never crash the app — show a concise error to the user and log details to the server
        latency_ms = (time.perf_counter() - t0) * 1000.0
        print("[ERROR] Inference failed:\n" + traceback.format_exc(), flush=True)
        return f"Error: {err}", "error", round(latency_ms, 1)

# -----------------------------
# Build Gradio Interface
# -----------------------------
with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
    gr.Markdown(f"### {TITLE}\n{DESCRIPTION}")

    with gr.Row():
        with gr.Column():
            img_in = gr.Image(
                type="pil",
                label="Image",
                height=320,
                sources=["upload", "drag-and-drop", "clipboard", "webcam"],
                image_mode="RGB",
            )
            q_in = gr.Textbox(
                label="Question",
                placeholder="e.g., What colour is the bus?",
                lines=2,
                max_lines=4,
            )
            source_in = gr.Radio(
                CONTROLLER_SOURCES,
                value="auto",
                label="Controller Source",
            )
            dist_in = gr.Radio(
                DISTILLED_CHOICES,
                value="auto",
                label="Distilled Gate (if used)",
            )
            run_btn = gr.Button("Predict", variant="primary")
        with gr.Column():
            ans_out = gr.Textbox(label="Answer", interactive=False, lines=3, max_lines=6)
            strat_out = gr.Textbox(label="Chosen Strategy", interactive=False)
            lat_out = gr.Number(label="Latency (ms)", precision=1, interactive=False)

    run_btn.click(
        vqa_demo_fn,
        inputs=[img_in, q_in, source_in, dist_in],
        outputs=[ans_out, strat_out, lat_out],
        api_name="predict",
    )

# -----------------------------
# Launch
# -----------------------------
if __name__ == "__main__":
    # Respect $PORT for Spaces/Render/Docker; default to 7860 locally
    port = int(os.getenv("PORT", "7860"))
    # Queue requests for robustness under load (Gradio 4 renamed queue()'s
    # concurrency_count parameter to default_concurrency_limit)
    demo.queue(default_concurrency_limit=2)
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        share=False,          # set True only for local quick sharing
        show_error=True,
    )
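
# Because the click handler is registered with api_name="predict", the endpoint can
# also be called programmatically, e.g. with gradio_client (a sketch; the file path
# and question are illustrative):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:7860")
#   answer, strategy, latency = client.predict(
#       handle_file("bus.jpg"),     # image
#       "What colour is the bus?",  # question
#       "auto",                     # controller source
#       "auto",                     # distilled gate
#       api_name="/predict",
#   )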