cd2412 committed on
Commit 278b37e · verified · 1 Parent(s): a7b143d

Update controller.py

Files changed (1):
  1. controller.py  +72 -11
controller.py CHANGED
@@ -1,13 +1,58 @@
-# controller.py
-"""
-Stub implementation of the model controller.
-Replace `answer_with_controller` with your real inference pipeline
-(e.g., InstructBLIP + PPO gate + memory retrieval).
-"""
+# controller.py — real VQA inference using BLIP (small, fast, no extra weights)
+# Works on CPU Space. Uses HF Hub to download the model at first run.

+import os
+import torch
 from PIL import Image
 from typing import Tuple

+from transformers import BlipForQuestionAnswering, BlipProcessor
+
+# ---------------------------
+# Load once at import time
+# ---------------------------
+HF_MODEL = os.getenv("HF_VQA_MODEL", "Salesforce/blip-vqa-base")  # small & good
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+_processor = None
+_model = None
+
+def _load():
+    global _processor, _model
+    if _processor is None or _model is None:
+        _processor = BlipProcessor.from_pretrained(HF_MODEL)
+        _model = BlipForQuestionAnswering.from_pretrained(HF_MODEL)
+        _model.to(DEVICE)
+        _model.eval()
+
+def _answer_baseline(image: Image.Image, question: str) -> str:
+    _load()
+    inputs = _processor(images=image, text=question, return_tensors="pt").to(DEVICE)
+    with torch.inference_mode():
+        out = _model.generate(**inputs, max_new_tokens=10)
+    ans = _processor.decode(out[0], skip_special_tokens=True)
+    return ans.strip()
+
+# --- optional future hooks (no-ops for now, keep API stable) ---
+def _answer_with_memory(image: Image.Image, question: str) -> str:
+    # Plug your FAISS/RAG here; fallback to baseline for now
+    return _answer_baseline(image, question)
+
+def _gate_auto(image: Image.Image, question: str) -> Tuple[int, str]:
+    # When PPO or distilled are wired, pick actions here. For now: baseline (0).
+    return 0, "baseline"
+
+def _gate_distilled(image: Image.Image, question: str) -> Tuple[int, str]:
+    # TODO: call your distilled classifier; fallback to baseline
+    return 0, "baseline"
+
+def _gate_ppo(image: Image.Image, question: str) -> Tuple[int, str]:
+    # TODO: call your PPO policy; fallback to baseline
+    return 0, "baseline"
+
+# ---------------------------
+# Public API for app.py
+# ---------------------------
 def answer_with_controller(
     image: Image.Image,
     question: str,
@@ -17,9 +62,25 @@ def answer_with_controller(
     """
     Returns:
        pred (str): predicted answer
-        strategy_name (str): chosen strategy
-        action_id (int): numeric ID of strategy
+        strategy_name (str): chosen strategy name
+        action_id (int): numeric action (0=baseline, 1=memory in future, etc.)
     """
-    # --- Dummy logic ---
-    # Always returns "Demo placeholder answer" with baseline strategy
-    return "Demo placeholder answer", "baseline", 0
+    source = (source or "auto").lower()
+
+    if source == "baseline":
+        ans = _answer_baseline(image, question)
+        return ans, "baseline", 0
+    elif source == "distilled":
+        aid, label = _gate_distilled(image, question)
+    elif source == "ppo":
+        aid, label = _gate_ppo(image, question)
+    else:  # auto
+        aid, label = _gate_auto(image, question)
+
+    # route by action id (for now all paths use baseline until you wire memory)
+    if aid == 1:
+        ans = _answer_with_memory(image, question)
+    else:
+        ans = _answer_baseline(image, question)
+
+    return ans, label, aid
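
For a quick local check of the committed API, a minimal sketch; the image path and question below are placeholders, and it assumes the elided part of the signature accepts a `source` keyword, as the function body suggests:

# smoke_test.py: hypothetical quick check, not part of this commit
from PIL import Image

from controller import answer_with_controller

img = Image.open("example.jpg").convert("RGB")  # placeholder image path
pred, strategy, action_id = answer_with_controller(
    img, "What color is the car?", source="baseline"
)
print(pred, strategy, action_id)  # BLIP's answer, "baseline", 0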
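And since the header comments mention a CPU Space and a "Public API for app.py", here is one way the app side could wire it up; the Gradio layout below is an assumption for illustration, not part of this commit:

# app.py-style wiring: hypothetical Gradio front end (assumed, not in this commit)
import gradio as gr

from controller import answer_with_controller

def ask(image, question, source):
    # controller returns (pred, strategy_name, action_id)
    pred, strategy, action_id = answer_with_controller(image, question, source=source)
    return pred, f"{strategy} (action {action_id})"

demo = gr.Interface(
    fn=ask,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Question"),
        gr.Dropdown(["auto", "baseline", "distilled", "ppo"], value="auto", label="Gate"),
    ],
    outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Strategy")],
)

if __name__ == "__main__":
    demo.launch()

On a CPU Space the first call is slow while _load() downloads the BLIP weights from the Hub; later calls reuse the cached processor and model.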