Update controller.py

controller.py  CHANGED  (+72 -11)
@@ -1,13 +1,58 @@
-# controller.py
-"""
-Stub implementation of the model controller.
-Replace `answer_with_controller` with your real inference pipeline
-(e.g., InstructBLIP + PPO gate + memory retrieval).
-"""
+# controller.py — real VQA inference using BLIP (small, fast, no extra weights)
+# Works on CPU Space. Uses HF Hub to download the model at first run.
 
+import os
+import torch
 from PIL import Image
 from typing import Tuple
 
+from transformers import BlipForQuestionAnswering, BlipProcessor
+
+# ---------------------------
+# Load once at import time
+# ---------------------------
+HF_MODEL = os.getenv("HF_VQA_MODEL", "Salesforce/blip-vqa-base")  # small & good
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+_processor = None
+_model = None
+
+def _load():
+    global _processor, _model
+    if _processor is None or _model is None:
+        _processor = BlipProcessor.from_pretrained(HF_MODEL)
+        _model = BlipForQuestionAnswering.from_pretrained(HF_MODEL)
+        _model.to(DEVICE)
+        _model.eval()
+
+def _answer_baseline(image: Image.Image, question: str) -> str:
+    _load()
+    inputs = _processor(images=image, text=question, return_tensors="pt").to(DEVICE)
+    with torch.inference_mode():
+        out = _model.generate(**inputs, max_new_tokens=10)
+    ans = _processor.decode(out[0], skip_special_tokens=True)
+    return ans.strip()
+
+# --- optional future hooks (no-ops for now, keep API stable) ---
+def _answer_with_memory(image: Image.Image, question: str) -> str:
+    # Plug your FAISS/RAG here; fallback to baseline for now
+    return _answer_baseline(image, question)
+
+def _gate_auto(image: Image.Image, question: str) -> Tuple[int, str]:
+    # When PPO or distilled are wired, pick actions here. For now: baseline (0).
+    return 0, "baseline"
+
+def _gate_distilled(image: Image.Image, question: str) -> Tuple[int, str]:
+    # TODO: call your distilled classifier; fallback to baseline
+    return 0, "baseline"
+
+def _gate_ppo(image: Image.Image, question: str) -> Tuple[int, str]:
+    # TODO: call your PPO policy; fallback to baseline
+    return 0, "baseline"
+
+# ---------------------------
+# Public API for app.py
+# ---------------------------
 def answer_with_controller(
     image: Image.Image,
     question: str,
@@ -17,9 +62,25 @@ def answer_with_controller(
     """
     Returns:
         pred (str): predicted answer
-        strategy_name (str): chosen strategy
-        action_id (int): numeric
+        strategy_name (str): chosen strategy name
+        action_id (int): numeric action (0=baseline, 1=memory in future, etc.)
     """
-
-
-
+    source = (source or "auto").lower()
+
+    if source == "baseline":
+        ans = _answer_baseline(image, question)
+        return ans, "baseline", 0
+    elif source == "distilled":
+        aid, label = _gate_distilled(image, question)
+    elif source == "ppo":
+        aid, label = _gate_ppo(image, question)
+    else:  # auto
+        aid, label = _gate_auto(image, question)
+
+    # route by action id (for now all paths use baseline until you wire memory)
+    if aid == 1:
+        ans = _answer_with_memory(image, question)
+    else:
+        ans = _answer_baseline(image, question)
+
+    return ans, label, aid
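
For reference, a minimal sketch of how app.py might call the public API above. The image path and question are placeholders, and passing `source` as a keyword assumes the signature lines elided between the two hunks (new lines 59-61) declare a `source: str = "auto"` parameter, as the function body implies; nothing here beyond `answer_with_controller` itself comes from this commit.

    # Hypothetical app-side caller; image path and question are placeholders.
    from PIL import Image
    from controller import answer_with_controller

    img = Image.open("example.jpg").convert("RGB")
    # `source` assumed to accept "auto" | "baseline" | "distilled" | "ppo",
    # matching the branches in answer_with_controller
    pred, strategy, action_id = answer_with_controller(img, "What color is the car?", source="auto")
    print(pred, strategy, action_id)  # strategy/action_id stay "baseline"/0 until the gates are wired

Since `HF_MODEL` is read via `os.getenv` at import time, setting the `HF_VQA_MODEL` environment variable before the first import of controller.py swaps in a different VQA checkpoint.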
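
`_answer_with_memory` is explicitly a placeholder ("Plug your FAISS/RAG here"), so here is one hedged sketch of a drop-in body for it, assuming it lives inside controller.py (so `Image` and `_answer_baseline` are in scope), a sentence-transformers text embedder, and a FAISS inner-product index. The embedder name, the toy memory pairs, and the 0.9 similarity threshold are all illustrative assumptions, not part of this commit.

    # Hypothetical replacement for _answer_with_memory inside controller.py;
    # embedder choice, toy memory, and threshold are assumptions.
    import numpy as np
    import faiss
    from sentence_transformers import SentenceTransformer

    _embedder = SentenceTransformer("all-MiniLM-L6-v2")

    # toy question -> answer memory; in practice built from past (question, answer) pairs
    _mem_questions = ["what color is the car", "how many dogs are there"]
    _mem_answers = ["red", "two"]
    _vecs = _embedder.encode(_mem_questions, normalize_embeddings=True).astype("float32")
    _index = faiss.IndexFlatIP(_vecs.shape[1])  # cosine similarity via inner product on unit vectors
    _index.add(_vecs)

    def _answer_with_memory(image: Image.Image, question: str) -> str:
        q = _embedder.encode([question], normalize_embeddings=True).astype("float32")
        score, idx = _index.search(q, 1)  # nearest stored question
        if score[0][0] >= 0.9:  # assumed similarity threshold
            return _mem_answers[idx[0][0]]
        return _answer_baseline(image, question)  # fall back to the BLIP baseline above

Wiring this in would also mean having `_gate_auto` return action 1 when a memory hit is confident, which is exactly the routing `answer_with_controller` already supports via its `aid == 1` branch.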