Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- detr_and_interp.py +442 -0
- requirements.txt +14 -0
detr_and_interp.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
this is a combined script that implements DETR object detection with interpretability methods
|
| 3 |
+
using Grad-CAM, Grad-CAM++, Integrated Gradients, and Monte Carlo Dropout for uncertainty estimation.
|
| 4 |
+
It provides a Gradio-based web interface for users to upload images, select detected objects
|
| 5 |
+
and visualize explanations and uncertainty maps.
|
| 6 |
+
|
| 7 |
+
How to run it:
|
| 8 |
+
|
| 9 |
+
```python
|
| 10 |
+
python detr_and_interp.py
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
'''
|
| 14 |
+
|
| 15 |
+
import torch, requests, numpy as np
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
import matplotlib.patches as patches
|
| 18 |
+
from PIL import Image, ImageFilter
|
| 19 |
+
import gradio as gr
|
| 20 |
+
from transformers import DetrImageProcessor, DetrForObjectDetection
|
| 21 |
+
from torchvision.transforms.functional import resize
|
| 22 |
+
from captum.attr import IntegratedGradients
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
import logging
|
| 25 |
+
import os
|
| 26 |
+
from datetime import datetime
|
| 27 |
+
|
| 28 |
+
# ---------- Logging Setup ----------
|
| 29 |
+
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
| 30 |
+
os.makedirs(log_dir, exist_ok=True)
|
| 31 |
+
log_file = os.path.join(log_dir, f"detr_interp_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
|
| 32 |
+
|
| 33 |
+
logging.basicConfig(
|
| 34 |
+
level=logging.INFO,
|
| 35 |
+
format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
|
| 36 |
+
handlers=[
|
| 37 |
+
logging.FileHandler(log_file),
|
| 38 |
+
logging.StreamHandler()
|
| 39 |
+
]
|
| 40 |
+
)
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
logger.info("Starting DETR Interpretability Dashboard")
|
| 44 |
+
|
| 45 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 46 |
+
logger.info(f"Using device: {device}")
|
| 47 |
+
|
| 48 |
+
model_name = "facebook/detr-resnet-50"
|
| 49 |
+
logger.info(f"Loading model: {model_name}")
|
| 50 |
+
model = DetrForObjectDetection.from_pretrained(model_name).to(device)
|
| 51 |
+
extractor = DetrImageProcessor.from_pretrained(model_name)
|
| 52 |
+
model.eval()
|
| 53 |
+
logger.info("Model loaded and set to evaluation mode")
|
| 54 |
+
|
| 55 |
+
# ---------- Grad-CAM / Grad-CAM++ ----------
|
| 56 |
+
def gradcam(img, det_idx, keep, pixel_values, use_pp=False):
|
| 57 |
+
"""
|
| 58 |
+
Compute Grad-CAM (or Grad-CAM++) heatmap for a selected detection.
|
| 59 |
+
|
| 60 |
+
What it computes:
|
| 61 |
+
- Captures feature-map activations from a late conv layer and the gradients of the
|
| 62 |
+
detection score w.r.t. those activations. Channel-wise weights are computed from
|
| 63 |
+
gradients and used to combine feature maps into a spatial heatmap.
|
| 64 |
+
|
| 65 |
+
Why this matters:
|
| 66 |
+
- Highlights which spatial regions the model used to make the prediction. Useful to
|
| 67 |
+
check whether the detector is attending to the object vs irrelevant background.
|
| 68 |
+
|
| 69 |
+
How to interpret results:
|
| 70 |
+
- High values in the returned heatmap indicate regions that contributed positively to
|
| 71 |
+
the detection score. Grad-CAM++ (use_pp=True) computes a refined weighting that often
|
| 72 |
+
yields sharper, better-localized maps when multiple instances overlap.
|
| 73 |
+
|
| 74 |
+
Caveats & tips:
|
| 75 |
+
- Choosing a layer too early will give fine-grained but semantically weak maps; too late
|
| 76 |
+
will be coarse. We pick a late backbone conv block (layer4[-1]) as a sensible default.
|
| 77 |
+
- Hooks must be removed after use to avoid memory leaks; we do that below.
|
| 78 |
+
|
| 79 |
+
References:
|
| 80 |
+
- Selvaraju et al., Grad-CAM (2017): https://arxiv.org/abs/1610.02391
|
| 81 |
+
"""
|
| 82 |
+
logger.info(f"Running {'Grad-CAM++' if use_pp else 'Grad-CAM'} for detection {det_idx}")
|
| 83 |
+
try:
|
| 84 |
+
# pick a late conv layer that still retains spatial info
|
| 85 |
+
conv_layer = model.model.backbone.conv_encoder.model.layer4[-1]
|
| 86 |
+
activations, gradients = {}, {}
|
| 87 |
+
|
| 88 |
+
def fwd(m, i, o):
|
| 89 |
+
activations["v"] = o.detach()
|
| 90 |
+
|
| 91 |
+
def bwd(m, gi, go):
|
| 92 |
+
gradients["v"] = go[0].detach()
|
| 93 |
+
|
| 94 |
+
h1 = conv_layer.register_forward_hook(fwd)
|
| 95 |
+
h2 = conv_layer.register_full_backward_hook(bwd) if hasattr(conv_layer, "register_full_backward_hook") else conv_layer.register_backward_hook(bwd)
|
| 96 |
+
logger.debug("Hooks registered for Grad-CAM")
|
| 97 |
+
|
| 98 |
+
outputs_for_attr = model(pixel_values)
|
| 99 |
+
logits = outputs_for_attr.logits
|
| 100 |
+
labels = logits.argmax(-1).squeeze(0)
|
| 101 |
+
label_id = labels[keep.nonzero()[det_idx]].item()
|
| 102 |
+
score = logits[0, keep.nonzero()[det_idx], label_id]
|
| 103 |
+
logger.debug(f"Target label_id: {label_id}, score: {score.item():.4f}")
|
| 104 |
+
|
| 105 |
+
model.zero_grad()
|
| 106 |
+
score.backward()
|
| 107 |
+
|
| 108 |
+
acts = activations["v"].squeeze(0)
|
| 109 |
+
grads = gradients["v"].squeeze(0)
|
| 110 |
+
logger.debug(f"Activations shape: {acts.shape}, Gradients shape: {grads.shape}")
|
| 111 |
+
|
| 112 |
+
if use_pp: # Grad-CAM++
|
| 113 |
+
weights = (grads ** 2).mean(dim=(1, 2)) / (2 * (grads ** 2).mean(dim=(1, 2)) + (acts * grads ** 3).mean(dim=(1, 2)) + 1e-8)
|
| 114 |
+
else: # vanilla Grad-CAM
|
| 115 |
+
weights = grads.mean(dim=(1, 2))
|
| 116 |
+
|
| 117 |
+
cam = torch.relu((weights[:, None, None] * acts).sum(0))
|
| 118 |
+
cam = cam / (cam.max() + 1e-8)
|
| 119 |
+
cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].cpu().numpy()
|
| 120 |
+
|
| 121 |
+
h1.remove(); h2.remove()
|
| 122 |
+
logger.info(f"{'Grad-CAM++' if use_pp else 'Grad-CAM'} completed successfully")
|
| 123 |
+
return cam_resized
|
| 124 |
+
except Exception as e:
|
| 125 |
+
logger.error(f"Error in gradcam: {str(e)}", exc_info=True)
|
| 126 |
+
raise
|
| 127 |
+
|
| 128 |
+
# ---------- Integrated Gradients ----------
|
| 129 |
+
def integrated_grad(img, det_idx, keep, outputs_for_attr, pixel_values, baseline="black"):
|
| 130 |
+
"""
|
| 131 |
+
Compute Integrated Gradients attribution map for a detection's logit.
|
| 132 |
+
|
| 133 |
+
What it computes:
|
| 134 |
+
- Integrates gradients along a path from a baseline input to the real input in embedding
|
| 135 |
+
space, producing per-pixel (or per-channel) attributions.
|
| 136 |
+
|
| 137 |
+
Why baseline choice matters:
|
| 138 |
+
- The baseline defines what the model should consider as 'no signal'. Common choices:
|
| 139 |
+
black (zeros), a blurred version of the image, or a neutral/mean image. Different
|
| 140 |
+
baselines highlight different aspects of the input.
|
| 141 |
+
|
| 142 |
+
How to read the output:
|
| 143 |
+
- Values > 0 indicate pixels that increase the detection logit vs baseline; values < 0
|
| 144 |
+
reduce it. We normalize the result to [0,1] for visualization convenience.
|
| 145 |
+
|
| 146 |
+
Tips:
|
| 147 |
+
- Increase n_steps for smoother attributions (costlier). Check convergence_delta to
|
| 148 |
+
validate IG's completeness property.
|
| 149 |
+
|
| 150 |
+
References:
|
| 151 |
+
- Distill article on baselines: https://distill.pub/2020/attribution-baselines
|
| 152 |
+
- Captum IntegratedGradients docs: https://captum.ai/api/integrated_gradients.html
|
| 153 |
+
"""
|
| 154 |
+
logger.info(f"Running Integrated Gradients with {baseline} baseline for detection {det_idx}")
|
| 155 |
+
try:
|
| 156 |
+
logits = outputs_for_attr.logits
|
| 157 |
+
labels = logits.argmax(-1).squeeze(0)
|
| 158 |
+
label_id = labels[keep.nonzero()[det_idx]].item()
|
| 159 |
+
logger.debug(f"IG target label_id: {label_id}")
|
| 160 |
+
|
| 161 |
+
# Baselines
|
| 162 |
+
if baseline == "black":
|
| 163 |
+
base = torch.zeros_like(pixel_values)
|
| 164 |
+
logger.debug("Using black baseline")
|
| 165 |
+
elif baseline == "blur":
|
| 166 |
+
blur = img.filter(ImageFilter.GaussianBlur(radius=15))
|
| 167 |
+
base = extractor(images=blur, return_tensors="pt")["pixel_values"].to(device)
|
| 168 |
+
logger.debug("Using blurred baseline")
|
| 169 |
+
else:
|
| 170 |
+
base = torch.zeros_like(pixel_values)
|
| 171 |
+
logger.debug("Defaulting to black baseline")
|
| 172 |
+
|
| 173 |
+
def forward_func(pix):
|
| 174 |
+
return model(pix).logits[:, keep.nonzero()[det_idx], label_id]
|
| 175 |
+
|
| 176 |
+
ig = IntegratedGradients(forward_func)
|
| 177 |
+
attr, _ = ig.attribute(pixel_values, baselines=base, n_steps=25, return_convergence_delta=True)
|
| 178 |
+
arr = attr.squeeze().mean(0).cpu().detach().numpy()
|
| 179 |
+
logger.info(f"Integrated Gradients with {baseline} baseline completed")
|
| 180 |
+
return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
|
| 181 |
+
except Exception as e:
|
| 182 |
+
logger.error(f"Error in integrated_grad: {str(e)}", exc_info=True)
|
| 183 |
+
raise
|
| 184 |
+
|
| 185 |
+
# ---------- Monte Carlo Dropout Uncertainty ----------
|
| 186 |
+
def mc_dropout_uncertainty(img, det_idx, keep, pixel_values, n_samples=20, dropout_p=0.1):
|
| 187 |
+
"""
|
| 188 |
+
Estimate uncertainty by running multiple stochastic forward passes with dropout active.
|
| 189 |
+
|
| 190 |
+
What it computes:
|
| 191 |
+
- Runs the model multiple times with dropout enabled and computes a CAM per run.
|
| 192 |
+
- Returns the per-pixel mean and standard deviation across CAMs. High std indicates
|
| 193 |
+
the model's focus is unstable across stochastic perturbations.
|
| 194 |
+
|
| 195 |
+
Why this helps:
|
| 196 |
+
- If heatmaps vary a lot, the interpretability output is less reliable. Use this to flag
|
| 197 |
+
detections where explanations may not be trustworthy.
|
| 198 |
+
|
| 199 |
+
Practical tips:
|
| 200 |
+
- Increasing n_samples reduces variance in the estimate but increases runtime.
|
| 201 |
+
- Temporarily sets the model to train mode to activate dropout modules; restores eval mode.
|
| 202 |
+
"""
|
| 203 |
+
logger.info(f"Running MC Dropout uncertainty: samples={n_samples}, p={dropout_p}, detection={det_idx}")
|
| 204 |
+
try:
|
| 205 |
+
def enable_dropout(m):
|
| 206 |
+
if isinstance(m, torch.nn.Dropout):
|
| 207 |
+
m.train()
|
| 208 |
+
|
| 209 |
+
model.train()
|
| 210 |
+
model.apply(enable_dropout)
|
| 211 |
+
|
| 212 |
+
cams = []
|
| 213 |
+
conv_layer = model.model.backbone.conv_encoder.model.layer4[-1]
|
| 214 |
+
|
| 215 |
+
for i in range(n_samples):
|
| 216 |
+
outputs = model(pixel_values)
|
| 217 |
+
logits = outputs.logits
|
| 218 |
+
labels = logits.argmax(-1).squeeze(0)
|
| 219 |
+
label_id = labels[keep.nonzero()[det_idx]].item()
|
| 220 |
+
score = logits[0, keep.nonzero()[det_idx], label_id]
|
| 221 |
+
|
| 222 |
+
acts, grads = {}, {}
|
| 223 |
+
|
| 224 |
+
def fwd(m, i, o):
|
| 225 |
+
acts['v'] = o.detach()
|
| 226 |
+
|
| 227 |
+
def bwd(m, gi, go):
|
| 228 |
+
grads['v'] = go[0].detach()
|
| 229 |
+
|
| 230 |
+
h1 = conv_layer.register_forward_hook(fwd)
|
| 231 |
+
h2 = (conv_layer.register_full_backward_hook(bwd)
|
| 232 |
+
if hasattr(conv_layer, 'register_full_backward_hook')
|
| 233 |
+
else conv_layer.register_backward_hook(bwd))
|
| 234 |
+
|
| 235 |
+
model.zero_grad()
|
| 236 |
+
score.backward(retain_graph=False)
|
| 237 |
+
|
| 238 |
+
if 'v' not in acts:
|
| 239 |
+
logger.warning(f"No activations captured in sample {i}, using fallback zero map")
|
| 240 |
+
cam_resized = np.zeros((img.size[1], img.size[0]))
|
| 241 |
+
else:
|
| 242 |
+
act = acts['v'].squeeze(0)
|
| 243 |
+
grad = grads['v'].squeeze(0)
|
| 244 |
+
weights = grad.mean(dim=(1, 2))
|
| 245 |
+
cam = torch.relu((weights[:, None, None] * act).sum(0))
|
| 246 |
+
cam = cam / (cam.max() + 1e-8)
|
| 247 |
+
cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].cpu().numpy()
|
| 248 |
+
|
| 249 |
+
cams.append(cam_resized)
|
| 250 |
+
h1.remove(); h2.remove()
|
| 251 |
+
|
| 252 |
+
model.eval()
|
| 253 |
+
|
| 254 |
+
if len(cams) == 0:
|
| 255 |
+
logger.error("No valid CAM maps generated")
|
| 256 |
+
return np.zeros((img.size[1], img.size[0])), np.zeros((img.size[1], img.size[0]))
|
| 257 |
+
|
| 258 |
+
cams_arr = np.stack(cams, axis=0)
|
| 259 |
+
mean_map = cams_arr.mean(0)
|
| 260 |
+
std_map = cams_arr.std(0)
|
| 261 |
+
|
| 262 |
+
mean_map = (mean_map - mean_map.min()) / (mean_map.max() - mean_map.min() + 1e-8)
|
| 263 |
+
std_map = (std_map - std_map.min()) / (std_map.max() - std_map.min() + 1e-8)
|
| 264 |
+
|
| 265 |
+
logger.info("MC Dropout uncertainty completed")
|
| 266 |
+
return mean_map, std_map
|
| 267 |
+
except Exception as e:
|
| 268 |
+
logger.error(f"Error in mc_dropout_uncertainty: {str(e)}", exc_info=True)
|
| 269 |
+
model.eval()
|
| 270 |
+
raise
|
| 271 |
+
|
| 272 |
+
# ---------- Full pipeline ----------
|
| 273 |
+
def interpret(img, det_choice, conf_thresh, cam_variant, mc_samples, dropout_p):
|
| 274 |
+
logger.info(f"Starting interpretation - detection: {det_choice}, threshold: {conf_thresh}, cam: {cam_variant}, mc_samples: {mc_samples}, dropout_p: {dropout_p}")
|
| 275 |
+
try:
|
| 276 |
+
inputs = extractor(images=img, return_tensors="pt").to(device)
|
| 277 |
+
with torch.no_grad(): outputs = model(**inputs)
|
| 278 |
+
pixel_values_attr = inputs["pixel_values"].clone().requires_grad_(True)
|
| 279 |
+
target_sizes = [img.size[::-1]]
|
| 280 |
+
results = extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
|
| 281 |
+
keep = results["scores"] > conf_thresh
|
| 282 |
+
labels, scores = results["labels"][keep], results["scores"][keep]
|
| 283 |
+
|
| 284 |
+
logger.info(f"Found {len(labels)} detections above threshold {conf_thresh}")
|
| 285 |
+
|
| 286 |
+
if len(labels) == 0:
|
| 287 |
+
logger.warning("No detections found above threshold")
|
| 288 |
+
return None, "No detections above threshold", None, ""
|
| 289 |
+
|
| 290 |
+
if det_choice is None:
|
| 291 |
+
det_idx = 0
|
| 292 |
+
else:
|
| 293 |
+
try: det_idx = int(str(det_choice).split(":")[0])
|
| 294 |
+
except: det_idx = 0
|
| 295 |
+
|
| 296 |
+
label = model.config.id2label[labels[det_idx].item()]
|
| 297 |
+
logger.info(f"Selected detection {det_idx}: {label}")
|
| 298 |
+
|
| 299 |
+
# Grad-CAM / Grad-CAM++ (single deterministic pass)
|
| 300 |
+
cam = gradcam(img, det_idx, keep, pixel_values_attr, use_pp=(cam_variant=="Grad-CAM++"))
|
| 301 |
+
fig1, ax1 = plt.subplots(); ax1.imshow(img); ax1.imshow(cam, cmap="jet", alpha=0.5); ax1.axis("off")
|
| 302 |
+
ax1.set_title(f"{cam_variant}: {label}"); plt.close(fig1)
|
| 303 |
+
logger.debug(f"{cam_variant} visualization created")
|
| 304 |
+
|
| 305 |
+
# MC Dropout Uncertainty analysis
|
| 306 |
+
mean_map, std_map = mc_dropout_uncertainty(img, det_idx, keep, pixel_values_attr, n_samples=int(mc_samples), dropout_p=float(dropout_p))
|
| 307 |
+
# Create a composite figure: mean map and std map side-by-side
|
| 308 |
+
fig2, axes = plt.subplots(1,2, figsize=(8,4))
|
| 309 |
+
axes[0].imshow(img); axes[0].imshow(mean_map, cmap='hot', alpha=0.5); axes[0].axis('off'); axes[0].set_title('Predictive Mean')
|
| 310 |
+
axes[1].imshow(img); axes[1].imshow(std_map, cmap='viridis', alpha=0.5); axes[1].axis('off'); axes[1].set_title('Predictive Std (Uncertainty)')
|
| 311 |
+
plt.close(fig2)
|
| 312 |
+
logger.debug("MC Dropout uncertainty visualization created")
|
| 313 |
+
|
| 314 |
+
exp1 = f"π {cam_variant}:\nGradient-weighted feature maps β highlights where DETR focused."
|
| 315 |
+
exp2 = f"π MC Dropout Uncertainty:\nSamples={mc_samples}, dropout={dropout_p}. Shows predictive mean and per-pixel std as uncertainty."
|
| 316 |
+
|
| 317 |
+
logger.info("Interpretation completed successfully")
|
| 318 |
+
return fig1, exp1, fig2, exp2
|
| 319 |
+
except Exception as e:
|
| 320 |
+
logger.error(f"Error in interpret function: {str(e)}", exc_info=True)
|
| 321 |
+
return None, f"Error: {str(e)}", None, ""
|
| 322 |
+
|
| 323 |
+
# ---------- Gradio UI ----------
|
| 324 |
+
with gr.Blocks() as demo:
|
| 325 |
+
gr.Markdown("## π§ DETR Interpretability Dashboard with Controls")
|
| 326 |
+
gr.Markdown(
|
| 327 |
+
"""
|
| 328 |
+
**How to use this dashboard**
|
| 329 |
+
|
| 330 |
+
- Upload an image using the left panel. The model will run object detection and list detected objects.
|
| 331 |
+
- Use the "Confidence Threshold" slider to filter detections by score. Detections below the threshold are hidden.
|
| 332 |
+
- Pick a detection from the dropdown to generate explanations for that object.
|
| 333 |
+
- Choose between `Grad-CAM` and `Grad-CAM++` (Grad-CAM++ often gives sharper, more localized maps).
|
| 334 |
+
- `MC Dropout Samples` controls how many stochastic forward passes are used to estimate prediction uncertainty. More samples give smoother estimates but take longer.
|
| 335 |
+
- `Dropout Probability` sets the dropout rate used during MC Dropout; higher values typically increase predicted uncertainty.
|
| 336 |
+
|
| 337 |
+
Tooltips are provided on each control (hover or focus) for quick hints.
|
| 338 |
+
"""
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
with gr.Row():
|
| 342 |
+
img_in = gr.Image(type="pil", label="Upload an image")
|
| 343 |
+
det_out = gr.Label(label="Detections")
|
| 344 |
+
det_fig = gr.Plot(label="Detections visualization")
|
| 345 |
+
det_choice = gr.Dropdown(label="Pick a detection for explanation")
|
| 346 |
+
|
| 347 |
+
with gr.Row():
|
| 348 |
+
conf_thresh = gr.Slider(0, 1, value=0.7, step=0.05, label="Confidence Threshold")
|
| 349 |
+
cam_variant = gr.Radio(["Grad-CAM", "Grad-CAM++"], value="Grad-CAM", label="Grad-CAM Variant")
|
| 350 |
+
mc_samples = gr.Slider(1, 100, value=20, step=1, label="MC Dropout Samples")
|
| 351 |
+
dropout_p = gr.Slider(0.0, 0.9, value=0.1, step=0.05, label="Dropout Probability")
|
| 352 |
+
|
| 353 |
+
btn = gr.Button("Explain")
|
| 354 |
+
|
| 355 |
+
gc_fig = gr.Plot(label="Grad-CAM / Grad-CAM++")
|
| 356 |
+
gc_txt = gr.Textbox(label="Explanation (Grad-CAM)")
|
| 357 |
+
unc_fig = gr.Plot(label="Uncertainty (MC Dropout)")
|
| 358 |
+
unc_txt = gr.Textbox(label="Explanation (Uncertainty)")
|
| 359 |
+
|
| 360 |
+
# Visible control tooltips section (for environments where hovering tooltips are not available)
|
| 361 |
+
gr.Markdown(
|
| 362 |
+
"""
|
| 363 |
+
**Control tooltips (quick reference)**
|
| 364 |
+
|
| 365 |
+
- Confidence Threshold: Filter out detections with confidence below this value.
|
| 366 |
+
- Grad-CAM Variant: Choose the gradient-based visualization method. Grad-CAM++ may highlight smaller regions more precisely.
|
| 367 |
+
- MC Dropout Samples: Number of stochastic forward passes for uncertainty estimation. Increase for more stable results.
|
| 368 |
+
- Dropout Probability: Dropout rate used during MC Dropout sampling. Higher values typically increase predictive variance.
|
| 369 |
+
- Pick a detection: Select which detected object to explain. Format shown as 'index: label (score)'.
|
| 370 |
+
"""
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
# ---------- Key interpretability choices (Feynman-style) ----------
|
| 374 |
+
gr.Markdown(
|
| 375 |
+
"""
|
| 376 |
+
**Key interpretability choices & why they matter**
|
| 377 |
+
|
| 378 |
+
- **Baseline (Integrated Gradients)**: Defines what 'no signal' looks like. Black (zeros) is simple, but blurred or neutral baselines may give more meaningful attributions.
|
| 379 |
+
- **Which conv layer for Grad-CAM**: Early layers give fine texture but low semantics; very late layers are coarse. A late backbone conv (default used) is a good compromise.
|
| 380 |
+
- **Number of MC Dropout samples**: More samples = smoother, more stable uncertainty estimates, but higher compute cost.
|
| 381 |
+
- **Grad-CAM vs Grad-CAM++**: Grad-CAM++ can be sharper and better for overlapping instances; vanilla Grad-CAM is faster and simpler.
|
| 382 |
+
"""
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
# ---------- Further reading / Feynman-style references ----------
|
| 386 |
+
# Add short, clickable references so users can read the original papers and deep-dive articles.
|
| 387 |
+
gr.Markdown(
|
| 388 |
+
"""
|
| 389 |
+
**Further reading (recommended)**
|
| 390 |
+
|
| 391 |
+
- [Grad-CAM β Selvaraju et al., 2017 (arXiv)](https://arxiv.org/abs/1610.02391) β the original Grad-CAM paper; explains the core idea of gradient-weighted localization.
|
| 392 |
+
- [Grad-CAM++ β Chattopadhay et al.](https://arxiv.org/abs/1710.11063) β an improved variant that often produces sharper maps and handles multiple instances better.
|
| 393 |
+
- [Visualizing the Impact of Feature Attribution Baselines (Distill)](https://distill.pub/2020/attribution-baselines) β an accessible deep dive on baseline choices for Integrated Gradients.
|
| 394 |
+
- [Captum docs β IntegratedGradients](https://captum.ai/api/integrated_gradients.html) β practical API notes for baseline, n_steps, and convergence delta.
|
| 395 |
+
- [Constructing sensible baselines for Integrated Gradients](https://arxiv.org/abs/2004.09627) β discussion and techniques for choosing baselines beyond a black image.
|
| 396 |
+
- [A New Baseline Assumption of Integrated Gradients Based on Shapley Values](https://arxiv.org/html/2310.04821v3) β recent research on improved baselines.
|
| 397 |
+
"""
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
# Helper: safe label getter in case model.config.id2label is missing or not a dict
|
| 401 |
+
def safe_label_lookup(idx):
|
| 402 |
+
try:
|
| 403 |
+
id2label = getattr(model.config, 'id2label', None)
|
| 404 |
+
if id2label is None:
|
| 405 |
+
return f"Class {idx}"
|
| 406 |
+
return id2label.get(int(idx), f"Class {idx}")
|
| 407 |
+
except Exception:
|
| 408 |
+
return f"Class {idx}"
|
| 409 |
+
|
| 410 |
+
def run_detect(img, conf_thresh):
|
| 411 |
+
logger.info(f"Running detection with confidence threshold: {conf_thresh}")
|
| 412 |
+
try:
|
| 413 |
+
inputs = extractor(images=img, return_tensors="pt").to(device)
|
| 414 |
+
with torch.no_grad(): outputs = model(**inputs)
|
| 415 |
+
target_sizes = [img.size[::-1]]
|
| 416 |
+
results = extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
|
| 417 |
+
keep = results["scores"] > conf_thresh
|
| 418 |
+
boxes, labels, scores = results["boxes"][keep], results["labels"][keep], results["scores"][keep]
|
| 419 |
+
|
| 420 |
+
logger.info(f"Detection found {len(labels)} objects above threshold")
|
| 421 |
+
|
| 422 |
+
det_list = [f"{i}: {safe_label_lookup(l.item())} ({s:.2f})" for i,(l,s) in enumerate(zip(labels,scores))]
|
| 423 |
+
fig, ax = plt.subplots(); ax.imshow(img); ax.axis("off")
|
| 424 |
+
for box,label,score in zip(boxes,labels,scores):
|
| 425 |
+
xmin,ymin,xmax,ymax = box
|
| 426 |
+
ax.add_patch(patches.Rectangle((xmin,ymin),xmax-xmin,ymax-ymin,fill=False,color="red",lw=2))
|
| 427 |
+
ax.text(xmin,ymin,f"{safe_label_lookup(label.item())}:{score:.2f}",color="black",
|
| 428 |
+
bbox=dict(facecolor="yellow",alpha=0.5))
|
| 429 |
+
plt.close(fig)
|
| 430 |
+
default_val = det_list[0] if len(det_list) > 0 else None
|
| 431 |
+
logger.debug("Detection visualization created")
|
| 432 |
+
return {det_out: str(det_list), det_fig: fig, det_choice: gr.update(choices=det_list, value=default_val)}
|
| 433 |
+
except Exception as e:
|
| 434 |
+
logger.error(f"Error in run_detect: {str(e)}", exc_info=True)
|
| 435 |
+
return {det_out: "Error in detection", det_fig: None, det_choice: gr.update(choices=[], value=None)}
|
| 436 |
+
|
| 437 |
+
img_in.change(run_detect, inputs=[img_in, conf_thresh], outputs=[det_out, det_fig, det_choice])
|
| 438 |
+
btn.click(interpret, inputs=[img_in, det_choice, conf_thresh, cam_variant, mc_samples, dropout_p],
|
| 439 |
+
outputs=[gc_fig, gc_txt, unc_fig, unc_txt])
|
| 440 |
+
|
| 441 |
+
logger.info("Gradio interface configured, launching demo")
|
| 442 |
+
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
matplotlib
|
| 4 |
+
captum
|
| 5 |
+
ipython
|
| 6 |
+
transformers
|
| 7 |
+
pillow
|
| 8 |
+
lime
|
| 9 |
+
numpy
|
| 10 |
+
scikit-image
|
| 11 |
+
timm
|
| 12 |
+
streamlit
|
| 13 |
+
gradio
|
| 14 |
+
accelerate
|