cd2412 committed
Commit cb277bc · verified · 1 Parent(s): 72a1f41

Update app.py

Files changed (1)
  1. app.py +15 -93
app.py CHANGED
@@ -1,131 +1,63 @@
-# app.py
-"""
-VQA — Memory + RL Controller (Gradio app)
-- Drag-and-drop an image, ask a question, and see the model's answer + chosen strategy.
-- Tries to import `answer_with_controller` from controller.py. Falls back to a stub if missing.
-- Works on Hugging Face Spaces, Render, Docker, or local run.
-"""
-
-import os
-import sys
-import time
-import traceback
-import subprocess
+# app.py (fixed)
+import os, sys, time, traceback, subprocess
 from typing import Tuple, Optional
+from PIL import Image
 
-# Ensure gradio is available when running locally; Spaces installs from requirements.txt
 try:
     import gradio as gr
-except ImportError:  # pragma: no cover
+except ImportError:
     subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "gradio"])
     import gradio as gr
 
-from PIL import Image
-
-# -----------------------------
-# Attempt to import real handler
-# -----------------------------
 def _make_fallback():
-    def _fallback_answer_with_controller(
-        image: Image.Image,
-        question: str,
-        source: str = "auto",
-        distilled_model: str = "auto",
-    ) -> Tuple[str, str, int]:
-        # Replace with real inference to remove this placeholder.
+    def _fallback_answer_with_controller(image, question, source="auto", distilled_model="auto"):
         return "Placeholder answer (wire your models in controller.py).", "baseline", 0
     return _fallback_answer_with_controller
 
 try:
-    # Expect controller.py to define: answer_with_controller(image, question, source, distilled_model)
-    from controller import answer_with_controller  # type: ignore
+    from controller import answer_with_controller
 except Exception as e:
     print(f"[WARN] Using fallback controller because import failed: {e}", flush=True)
     answer_with_controller = _make_fallback()
 
-# -----------------------------
-# UI Constants
-# -----------------------------
 TITLE = "VQA — Memory + RL Controller"
-DESCRIPTION = (
-    "Upload an image, enter a question, and the controller will choose the best decoding strategy."
-)
+DESCRIPTION = "Upload an image, enter a question, and the controller will choose the best decoding strategy."
 
 CONTROLLER_SOURCES = ["auto", "distilled", "ppo", "baseline"]
 DISTILLED_CHOICES = ["auto", "logreg", "mlp32"]
 
-# -----------------------------
-# Inference wrapper with guards
-# -----------------------------
-def vqa_demo_fn(
-    image: Optional[Image.Image],
-    question: str,
-    source: str,
-    distilled_model: str,
-) -> Tuple[str, str, float]:
-    """Safely run inference and return (answer, strategy_label, latency_ms)."""
-    # Input validation
+def vqa_demo_fn(image: Optional[Image.Image], question: str, source: str, distilled_model: str) -> Tuple[str, str, float]:
     if image is None:
         return "Please upload an image.", "", 0.0
     question = (question or "").strip()
     if not question:
         return "Please enter a question.", "", 0.0
-
-    # Convert & measure latency
     t0 = time.perf_counter()
     try:
-        # Convert to RGB to avoid issues with PNG/L mode
         image_rgb = image.convert("RGB")
-
         pred, strategy_name, action_id = answer_with_controller(
-            image_rgb,
-            question,
-            source=source,
-            distilled_model=distilled_model,
+            image_rgb, question, source=source, distilled_model=distilled_model
         )
-
         latency_ms = (time.perf_counter() - t0) * 1000.0
-        # Friendly formatting
-        strategy_out = f"{action_id} → {strategy_name}"
-        return str(pred), strategy_out, round(latency_ms, 1)
-
+        return str(pred), f"{action_id} → {strategy_name}", round(latency_ms, 1)
     except Exception as err:
-        # Never crash the app — show a concise error to the user and log details to server
         latency_ms = (time.perf_counter() - t0) * 1000.0
         print("[ERROR] Inference failed:\n" + "".join(traceback.format_exc()), flush=True)
         return f"Error: {err}", "error", round(latency_ms, 1)
 
-# -----------------------------
-# Build Gradio Interface
-# -----------------------------
 with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
     gr.Markdown(f"### {TITLE}\n{DESCRIPTION}")
-
     with gr.Row():
         with gr.Column():
             img_in = gr.Image(
                 type="pil",
                 label="Image",
                 height=320,
-                sources=["upload", "drag-and-drop", "clipboard", "webcam"],
-                image_mode="RGB",
-            )
-            q_in = gr.Textbox(
-                label="Question",
-                placeholder="e.g., What colour is the bus?",
-                lines=2,
-                max_lines=4,
-            )
-            source_in = gr.Radio(
-                CONTROLLER_SOURCES,
-                value="auto",
-                label="Controller Source",
-            )
-            dist_in = gr.Radio(
-                DISTILLED_CHOICES,
-                value="auto",
-                label="Distilled Gate (if used)",
+                sources=["upload", "webcam", "clipboard"],  # ✅ fixed
             )
+            q_in = gr.Textbox(label="Question", placeholder="e.g., What colour is the bus?", lines=2, max_lines=4)
+            source_in = gr.Radio(CONTROLLER_SOURCES, value="auto", label="Controller Source")
+            dist_in = gr.Radio(DISTILLED_CHOICES, value="auto", label="Distilled Gate (if used)")
             run_btn = gr.Button("Predict", variant="primary")
         with gr.Column():
             ans_out = gr.Textbox(label="Answer", interactive=False, lines=3, max_lines=6)
@@ -139,17 +71,7 @@ with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
         api_name="predict",
     )
 
-# -----------------------------
-# Launch
-# -----------------------------
 if __name__ == "__main__":
-    # Respect $PORT for Spaces/Render/Docker; default to 7860 locally
     port = int(os.getenv("PORT", "7860"))
-    # Queue improves robustness under load
     demo.queue(concurrency_count=2)
-    demo.launch(
-        server_name="0.0.0.0",
-        server_port=port,
-        share=False,  # set True only for local quick sharing
-        show_error=True,
-    )
+    demo.launch(server_name="0.0.0.0", server_port=port, share=False, show_error=True)
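
The try/except import dance exists because controller.py may be absent from the Space: app.py expects it to define answer_with_controller(image, question, source, distilled_model) returning an (answer, strategy_name, action_id) triple, and substitutes the stub otherwise. A minimal controller.py satisfying that contract could look like the sketch below; the signature and 3-tuple return shape come from the diff, while the strategy table and stub answer are illustrative assumptions, not the Space's real models.

# controller.py: minimal sketch of the contract app.py expects.
# Signature and 3-tuple return are taken from the diff; the body is a stub.
from typing import Tuple
from PIL import Image

STRATEGIES = ["baseline", "distilled", "ppo"]  # assumed action space

def answer_with_controller(
    image: Image.Image,
    question: str,
    source: str = "auto",
    distilled_model: str = "auto",
) -> Tuple[str, str, int]:
    # A real controller would pick a decoding strategy from image/question
    # features; this stub just maps `source` straight onto the action space.
    action_id = STRATEGIES.index(source) if source in STRATEGIES else 0
    strategy_name = STRATEGIES[action_id]
    answer = f"(stub) would answer {question!r} via {strategy_name}"
    return answer, strategy_name, action_id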
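One thing the commit leaves untouched: demo.queue(concurrency_count=2) is Gradio 3.x API; Gradio 4.x removed that parameter in favour of default_concurrency_limit, so the call raises a TypeError there. A version-tolerant sketch (the enable_queue helper is hypothetical, not part of the commit):

import gradio as gr

def enable_queue(demo: gr.Blocks, limit: int = 2) -> None:
    # Gradio 4.x spells the knob default_concurrency_limit; Gradio 3.x used
    # concurrency_count. Try the newer name first, fall back on TypeError.
    try:
        demo.queue(default_concurrency_limit=limit)
    except TypeError:
        demo.queue(concurrency_count=limit)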