Skier8402 committed on
Commit
ef2badf
·
verified ·
1 Parent(s): 425ef18

Upload 2 files

Browse files
Files changed (2) hide show
  1. detr_and_interp.py +442 -0
  2. requirements.txt +14 -0
detr_and_interp.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ This is a combined script that implements DETR object detection with interpretability methods
3
+ using Grad-CAM, Grad-CAM++, Integrated Gradients, and Monte Carlo Dropout for uncertainty estimation.
4
+ It provides a Gradio-based web interface for users to upload images, select detected objects
5
+ and visualize explanations and uncertainty maps.
6
+
7
+ How to run it:
8
+
9
+ ```bash
10
+ python detr_and_interp.py
11
+ ```
12
+
13
+ '''
14
+
15
+ import torch, requests, numpy as np
16
+ import matplotlib.pyplot as plt
17
+ import matplotlib.patches as patches
18
+ from PIL import Image, ImageFilter
19
+ import gradio as gr
20
+ from transformers import DetrImageProcessor, DetrForObjectDetection
21
+ from torchvision.transforms.functional import resize
22
+ from captum.attr import IntegratedGradients
23
+ import torch.nn.functional as F
24
+ import logging
25
+ import os
26
+ from datetime import datetime
27
+
28
# ---------- Logging Setup ----------
# Every run logs to its own timestamped file inside a "logs" directory that
# sits next to this script (created on demand), and mirrors records to the
# console through a StreamHandler.
log_dir = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(log_dir, exist_ok=True)
_run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = os.path.join(log_dir, f"detr_interp_{_run_stamp}.log")

logging.basicConfig(
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

logger.info("Starting DETR Interpretability Dashboard")

# ---------- Device and model ----------
# Use the GPU when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
logger.info(f"Using device: {device}")

# The detector and its image processor are loaded from the same checkpoint
# name; the detector is moved to the chosen device and put in eval mode.
model_name = "facebook/detr-resnet-50"
logger.info(f"Loading model: {model_name}")
model = DetrForObjectDetection.from_pretrained(model_name)
model = model.to(device)
extractor = DetrImageProcessor.from_pretrained(model_name)
model.eval()
logger.info("Model loaded and set to evaluation mode")
54
+
55
+ # ---------- Grad-CAM / Grad-CAM++ ----------
56
def gradcam(img, det_idx, keep, pixel_values, use_pp=False):
    """
    Compute a Grad-CAM (or Grad-CAM++) heatmap for one selected detection.

    What it computes:
    - Captures the feature-map activations of the last block of the backbone's
      ``layer4`` and the gradients of the selected detection's class logit with
      respect to those activations. Per-channel weights derived from the
      gradients combine the feature maps into a single spatial heatmap.

    How to interpret results:
    - High values mark regions that contributed positively to the detection
      logit. Grad-CAM++ (``use_pp=True``) weights each spatial gradient
      individually, which often gives sharper maps when instances overlap.

    Args:
        img: PIL image; only ``img.size`` is read, to resize the heatmap.
        det_idx: Position of the detection among the True entries of ``keep``.
        keep: 1-D boolean mask; ``keep.nonzero()[det_idx]`` is used as the row
            (query) index into the model's logits.
        pixel_values: Input tensor passed directly to ``model``.
        use_pp: If True use Grad-CAM++ weighting, otherwise vanilla Grad-CAM.

    Returns:
        2-D numpy array of shape ``img.size[::-1]`` holding the ReLU'd,
        max-scaled class activation map.

    Raises:
        Any exception raised while running the model or building the map is
        logged and re-raised. The hooks are removed in every case.

    References:
    - Selvaraju et al., Grad-CAM (2017): https://arxiv.org/abs/1610.02391
    - Chattopadhay et al., Grad-CAM++: https://arxiv.org/abs/1710.11063
    """
    method = "Grad-CAM++" if use_pp else "Grad-CAM"
    logger.info(f"Running {method} for detection {det_idx}")

    # pick a late conv layer that still retains spatial info
    conv_layer = model.model.backbone.conv_encoder.model.layer4[-1]
    activations, gradients = {}, {}

    def fwd(m, i, o):
        activations["v"] = o.detach()

    def bwd(m, gi, go):
        gradients["v"] = go[0].detach()

    handles = []
    try:
        handles.append(conv_layer.register_forward_hook(fwd))
        if hasattr(conv_layer, "register_full_backward_hook"):
            handles.append(conv_layer.register_full_backward_hook(bwd))
        else:
            handles.append(conv_layer.register_backward_hook(bwd))
        logger.debug("Hooks registered for Grad-CAM")

        logits = model(pixel_values).logits
        query_idx = int(keep.nonzero()[det_idx])
        # The last logit is DETR's "no object" class; post-processing scores
        # detections over the remaining classes only, so the target class is
        # chosen the same way here.
        label_id = int(logits[0, query_idx, :-1].argmax())
        score = logits[0, query_idx, label_id]
        logger.debug(f"Target label_id: {label_id}, score: {score.item():.4f}")

        model.zero_grad()
        score.backward()

        acts = activations["v"].squeeze(0)
        grads = gradients["v"].squeeze(0)
        logger.debug(f"Activations shape: {acts.shape}, Gradients shape: {grads.shape}")

        if use_pp:  # Grad-CAM++
            # alpha_ij = g^2 / (2 g^2 + (sum_ab A_ab) g^3); w_k = sum_ij alpha_ij * relu(g_ij)
            grads_sq = grads ** 2
            act_sum = acts.sum(dim=(1, 2), keepdim=True)
            alpha = grads_sq / (2 * grads_sq + act_sum * grads_sq * grads + 1e-8)
            # Positions with zero gradient carry no information; zero them so
            # the epsilon-dominated ratio cannot leak into the weights.
            alpha = torch.where(grads != 0, alpha, torch.zeros_like(alpha))
            weights = (alpha * torch.relu(grads)).sum(dim=(1, 2))
        else:  # vanilla Grad-CAM
            weights = grads.mean(dim=(1, 2))

        cam = torch.relu((weights[:, None, None] * acts).sum(0))
        cam = cam / (cam.max() + 1e-8)
        cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].cpu().numpy()

        logger.info(f"{method} completed successfully")
        return cam_resized
    except Exception as e:
        logger.error(f"Error in gradcam: {str(e)}", exc_info=True)
        raise
    finally:
        # Always detach the hooks, including on failure, so they do not pile
        # up on the shared model across calls.
        for handle in handles:
            handle.remove()
127
+
128
+ # ---------- Integrated Gradients ----------
129
def integrated_grad(img, det_idx, keep, outputs_for_attr, pixel_values, baseline="black", n_steps=25):
    """
    Compute an Integrated Gradients attribution map for a detection's logit.

    What it computes:
    - Integrates gradients of the selected detection's class logit along a
      path from a baseline input to ``pixel_values`` (via Captum), then
      averages the attributions over the channel axis and min-max scales the
      result to [0, 1] for visualization.

    Why baseline choice matters:
    - The baseline defines what the model should consider 'no signal'.
      ``"black"`` uses an all-zeros tensor; ``"blur"`` uses a Gaussian-blurred
      copy of ``img`` run through the image processor. Any other value falls
      back to zeros.
    - NOTE(review): the zeros baseline is built in processed-input space. If
      the processor normalises pixel values, zeros may correspond to a mean
      colour rather than true black. TODO confirm.

    Args:
        img: PIL image; only used to build the blurred baseline.
        det_idx: Position of the detection among the True entries of ``keep``.
        keep: 1-D boolean mask; ``keep.nonzero()[det_idx]`` is used as the row
            (query) index into the logits.
        outputs_for_attr: Model output whose ``.logits`` select the target class.
        pixel_values: Input tensor to attribute.
        baseline: ``"black"`` or ``"blur"``.
        n_steps: Number of integration steps passed to Captum. More steps give
            a better approximation at higher cost.

    Returns:
        2-D numpy array with the spatial size of ``pixel_values``, scaled to
        [0, 1]. It is not resized to ``img``.

    Raises:
        Any exception raised during attribution is logged and re-raised.

    References:
    - Distill article on baselines: https://distill.pub/2020/attribution-baselines
    - Captum IntegratedGradients docs: https://captum.ai/api/integrated_gradients.html
    """
    logger.info(f"Running Integrated Gradients with {baseline} baseline for detection {det_idx}")
    try:
        logits = outputs_for_attr.logits
        query_idx = int(keep.nonzero()[det_idx])
        # Skip the trailing "no object" logit so the target matches the class
        # reported by post-processing.
        label_id = int(logits[0, query_idx, :-1].argmax())
        logger.debug(f"IG target label_id: {label_id}")

        # Baselines
        if baseline == "blur":
            blur = img.filter(ImageFilter.GaussianBlur(radius=15))
            base = extractor(images=blur, return_tensors="pt")["pixel_values"].to(device)
            logger.debug("Using blurred baseline")
        else:
            base = torch.zeros_like(pixel_values)
            if baseline == "black":
                logger.debug("Using black baseline")
            else:
                logger.debug("Defaulting to black baseline")

        def forward_func(pix):
            # One scalar per batch element, as Captum expects when no target is given.
            return model(pix).logits[:, query_idx, label_id]

        ig = IntegratedGradients(forward_func)
        attr, delta = ig.attribute(
            pixel_values, baselines=base, n_steps=n_steps, return_convergence_delta=True
        )
        # A large delta means the path integral is poorly approximated; raise n_steps.
        logger.info(f"IG convergence delta (max abs): {delta.abs().max().item():.6f}")

        # Drop only the batch axis, then average over channels.
        arr = attr.squeeze(0).mean(0).cpu().detach().numpy()
        logger.info(f"Integrated Gradients with {baseline} baseline completed")
        return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
    except Exception as e:
        logger.error(f"Error in integrated_grad: {str(e)}", exc_info=True)
        raise
184
+
185
+ # ---------- Monte Carlo Dropout Uncertainty ----------
186
def mc_dropout_uncertainty(img, det_idx, keep, pixel_values, n_samples=20, dropout_p=0.1):
    """
    Estimate explanation uncertainty from repeated stochastic forward passes.

    What it computes:
    - Puts the model in train mode, runs ``n_samples`` forward/backward passes
      and builds a vanilla Grad-CAM map (last block of the backbone's
      ``layer4``) for each pass.
    - Returns the per-pixel mean and standard deviation across those maps,
      each min-max scaled to [0, 1]. A high std means the model's focus is
      unstable across passes.

    Args:
        img: PIL image; only ``img.size`` is read, to size the maps.
        det_idx: Position of the detection among the True entries of ``keep``.
        keep: 1-D boolean mask; ``keep.nonzero()[det_idx]`` is used as the row
            (query) index into the logits.
        pixel_values: Input tensor passed directly to ``model``.
        n_samples: Number of stochastic passes.
        dropout_p: Dropout rate applied for the duration of the call. It is
            written to ``p`` of every ``torch.nn.Dropout`` module and to any
            float ``dropout`` attribute found on a submodule; the original
            values are restored afterwards.

    Returns:
        ``(mean_map, std_map)`` as 2-D numpy arrays of shape ``img.size[::-1]``.
        Both are all-zero when ``n_samples`` is 0.

    Raises:
        Any exception is logged and re-raised. Hooks, dropout rates and eval
        mode are restored in every case.
    """
    logger.info(f"Running MC Dropout uncertainty: samples={n_samples}, p={dropout_p}, detection={det_idx}")

    conv_layer = model.model.backbone.conv_encoder.model.layer4[-1]
    acts, grads = {}, {}

    def fwd(m, i, o):
        acts['v'] = o.detach()

    def bwd(m, gi, go):
        grads['v'] = go[0].detach()

    handles = []
    overridden = []  # (module, attribute name, original value)
    try:
        # NOTE(review): train() also switches any non-frozen normalisation
        # layers to training mode; confirm the backbone's batch-norm layers
        # are frozen for this checkpoint.
        model.train()

        # Apply the requested dropout rate, remembering what to restore.
        # NOTE(review): float `dropout` attributes are presumed to be the rate
        # handed to functional dropout by the owning layer. Verify against the
        # installed transformers version.
        for module in model.modules():
            if isinstance(module, torch.nn.Dropout):
                overridden.append((module, "p", module.p))
                module.p = dropout_p
            elif isinstance(getattr(module, "dropout", None), float):
                overridden.append((module, "dropout", module.dropout))
                module.dropout = dropout_p

        # Hooks must be in place BEFORE the forward pass, otherwise neither
        # the activations nor the gradients of that pass are captured.
        handles.append(conv_layer.register_forward_hook(fwd))
        if hasattr(conv_layer, 'register_full_backward_hook'):
            handles.append(conv_layer.register_full_backward_hook(bwd))
        else:
            handles.append(conv_layer.register_backward_hook(bwd))

        query_idx = int(keep.nonzero()[det_idx])
        map_shape = (img.size[1], img.size[0])
        cams = []

        for sample_idx in range(n_samples):
            acts.clear()
            grads.clear()

            logits = model(pixel_values).logits
            # Exclude the trailing "no object" logit when picking the target class.
            label_id = int(logits[0, query_idx, :-1].argmax())
            score = logits[0, query_idx, label_id]

            model.zero_grad()
            score.backward()

            if 'v' not in acts or 'v' not in grads:
                logger.warning(f"No activations captured in sample {sample_idx}, using fallback zero map")
                cam_resized = np.zeros(map_shape)
            else:
                act = acts['v'].squeeze(0)
                grad = grads['v'].squeeze(0)
                weights = grad.mean(dim=(1, 2))
                cam = torch.relu((weights[:, None, None] * act).sum(0))
                cam = cam / (cam.max() + 1e-8)
                cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].cpu().numpy()

            cams.append(cam_resized)

        if len(cams) == 0:
            logger.error("No valid CAM maps generated")
            return np.zeros(map_shape), np.zeros(map_shape)

        cams_arr = np.stack(cams, axis=0)
        mean_map = cams_arr.mean(0)
        std_map = cams_arr.std(0)

        mean_map = (mean_map - mean_map.min()) / (mean_map.max() - mean_map.min() + 1e-8)
        std_map = (std_map - std_map.min()) / (std_map.max() - std_map.min() + 1e-8)

        logger.info("MC Dropout uncertainty completed")
        return mean_map, std_map
    except Exception as e:
        logger.error(f"Error in mc_dropout_uncertainty: {str(e)}", exc_info=True)
        raise
    finally:
        # Undo every temporary change to the shared model, success or not.
        for handle in handles:
            handle.remove()
        for module, attr_name, original in overridden:
            setattr(module, attr_name, original)
        model.eval()
271
+
272
+ # ---------- Full pipeline ----------
273
def interpret(img, det_choice, conf_thresh, cam_variant, mc_samples, dropout_p):
    """
    Run the full explanation pipeline for one detection (Gradio callback).

    Steps: run detection on ``img``, keep detections scoring above
    ``conf_thresh``, resolve the selected detection, then build a
    Grad-CAM / Grad-CAM++ overlay and an MC-Dropout mean/std overlay.

    Args:
        img: PIL image from the upload widget, or None if nothing is uploaded.
        det_choice: Dropdown value whose text before the first ':' is the
            detection index; None or unparsable values select index 0.
        conf_thresh: Minimum score for a detection to be kept.
        cam_variant: ``"Grad-CAM++"`` selects Grad-CAM++; any other value
            selects vanilla Grad-CAM.
        mc_samples: Number of MC-Dropout passes (converted with ``int``).
        dropout_p: Dropout rate for MC Dropout (converted with ``float``).

    Returns:
        ``(cam_figure, cam_text, uncertainty_figure, uncertainty_text)``.
        On any problem the figures are None and ``cam_text`` carries the
        message; this function does not raise.
    """
    logger.info(f"Starting interpretation - detection: {det_choice}, threshold: {conf_thresh}, cam: {cam_variant}, mc_samples: {mc_samples}, dropout_p: {dropout_p}")
    try:
        if img is None:
            logger.warning("No image provided")
            return None, "Please upload an image first", None, ""

        inputs = extractor(images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Separate leaf tensor so the attribution passes can backpropagate to the input.
        pixel_values_attr = inputs["pixel_values"].clone().requires_grad_(True)

        target_sizes = [img.size[::-1]]
        # threshold=0.0 so that `keep` is a mask we compute ourselves; the
        # attribution helpers use it to map a detection index to a logits row.
        results = extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
        keep = results["scores"] > conf_thresh
        labels = results["labels"][keep]

        logger.info(f"Found {len(labels)} detections above threshold {conf_thresh}")

        if len(labels) == 0:
            logger.warning("No detections found above threshold")
            return None, "No detections above threshold", None, ""

        # Dropdown entries look like "index: label (score)".
        try:
            det_idx = int(str(det_choice).split(":")[0])
        except (ValueError, TypeError):
            det_idx = 0

        # The dropdown may be stale (e.g. the threshold changed after
        # detection ran); report that clearly instead of failing on an index.
        if not 0 <= det_idx < len(labels):
            logger.warning(f"Detection index {det_idx} out of range for {len(labels)} detections")
            return None, f"Detection {det_idx} is not available at threshold {conf_thresh}; please pick a detection again", None, ""

        label = model.config.id2label[labels[det_idx].item()]
        logger.info(f"Selected detection {det_idx}: {label}")

        # Grad-CAM / Grad-CAM++ (single deterministic pass)
        cam = gradcam(img, det_idx, keep, pixel_values_attr, use_pp=(cam_variant == "Grad-CAM++"))
        fig1, ax1 = plt.subplots()
        ax1.imshow(img)
        ax1.imshow(cam, cmap="jet", alpha=0.5)
        ax1.axis("off")
        ax1.set_title(f"{cam_variant}: {label}")
        plt.close(fig1)  # release pyplot's reference; the figure object is still returned
        logger.debug(f"{cam_variant} visualization created")

        # MC Dropout Uncertainty analysis
        mean_map, std_map = mc_dropout_uncertainty(
            img, det_idx, keep, pixel_values_attr,
            n_samples=int(mc_samples), dropout_p=float(dropout_p),
        )
        # Create a composite figure: mean map and std map side-by-side
        fig2, axes = plt.subplots(1, 2, figsize=(8, 4))
        axes[0].imshow(img)
        axes[0].imshow(mean_map, cmap='hot', alpha=0.5)
        axes[0].axis('off')
        axes[0].set_title('Predictive Mean')
        axes[1].imshow(img)
        axes[1].imshow(std_map, cmap='viridis', alpha=0.5)
        axes[1].axis('off')
        axes[1].set_title('Predictive Std (Uncertainty)')
        plt.close(fig2)
        logger.debug("MC Dropout uncertainty visualization created")

        exp1 = f"🔎 {cam_variant}:\nGradient-weighted feature maps → highlights where DETR focused."
        exp2 = f"🔎 MC Dropout Uncertainty:\nSamples={mc_samples}, dropout={dropout_p}. Shows predictive mean and per-pixel std as uncertainty."

        logger.info("Interpretation completed successfully")
        return fig1, exp1, fig2, exp2
    except Exception as e:
        # UI boundary: surface the failure in the text box rather than crashing the callback.
        logger.error(f"Error in interpret function: {str(e)}", exc_info=True)
        return None, f"Error: {str(e)}", None, ""
322
+
323
+ # ---------- Gradio UI ----------
324
# Static Markdown for the dashboard, kept outside the layout so the text is not
# indented inside the strings and the layout code stays readable.
_USAGE_MD = """
**How to use this dashboard**

- Upload an image using the left panel. The model will run object detection and list detected objects.
- Use the "Confidence Threshold" slider to filter detections by score. Detections below the threshold are hidden.
- Pick a detection from the dropdown to generate explanations for that object.
- Choose between `Grad-CAM` and `Grad-CAM++` (Grad-CAM++ often gives sharper, more localized maps).
- `MC Dropout Samples` controls how many stochastic forward passes are used to estimate prediction uncertainty. More samples give smoother estimates but take longer.
- `Dropout Probability` sets the dropout rate used during MC Dropout; higher values typically increase predicted uncertainty.

Tooltips are provided on each control (hover or focus) for quick hints.
"""

_TOOLTIPS_MD = """
**Control tooltips (quick reference)**

- Confidence Threshold: Filter out detections with confidence below this value.
- Grad-CAM Variant: Choose the gradient-based visualization method. Grad-CAM++ may highlight smaller regions more precisely.
- MC Dropout Samples: Number of stochastic forward passes for uncertainty estimation. Increase for more stable results.
- Dropout Probability: Dropout rate used during MC Dropout sampling. Higher values typically increase predictive variance.
- Pick a detection: Select which detected object to explain. Format shown as 'index: label (score)'.
"""

_CHOICES_MD = """
**Key interpretability choices & why they matter**

- **Baseline (Integrated Gradients)**: Defines what 'no signal' looks like. Black (zeros) is simple, but blurred or neutral baselines may give more meaningful attributions.
- **Which conv layer for Grad-CAM**: Early layers give fine texture but low semantics; very late layers are coarse. A late backbone conv (default used) is a good compromise.
- **Number of MC Dropout samples**: More samples = smoother, more stable uncertainty estimates, but higher compute cost.
- **Grad-CAM vs Grad-CAM++**: Grad-CAM++ can be sharper and better for overlapping instances; vanilla Grad-CAM is faster and simpler.
"""

_READING_MD = """
**Further reading (recommended)**

- [Grad-CAM — Selvaraju et al., 2017 (arXiv)](https://arxiv.org/abs/1610.02391) — the original Grad-CAM paper; explains the core idea of gradient-weighted localization.
- [Grad-CAM++ — Chattopadhay et al.](https://arxiv.org/abs/1710.11063) — an improved variant that often produces sharper maps and handles multiple instances better.
- [Visualizing the Impact of Feature Attribution Baselines (Distill)](https://distill.pub/2020/attribution-baselines) — an accessible deep dive on baseline choices for Integrated Gradients.
- [Captum docs — IntegratedGradients](https://captum.ai/api/integrated_gradients.html) — practical API notes for baseline, n_steps, and convergence delta.
- [Constructing sensible baselines for Integrated Gradients](https://arxiv.org/abs/2004.09627) — discussion and techniques for choosing baselines beyond a black image.
- [A New Baseline Assumption of Integrated Gradients Based on Shapley Values](https://arxiv.org/html/2310.04821v3) — recent research on improved baselines.
"""

with gr.Blocks() as demo:
    gr.Markdown("## 🧠 DETR Interpretability Dashboard with Controls")
    gr.Markdown(_USAGE_MD)

    with gr.Row():
        img_in = gr.Image(type="pil", label="Upload an image")
        det_out = gr.Label(label="Detections")
        det_fig = gr.Plot(label="Detections visualization")
        det_choice = gr.Dropdown(label="Pick a detection for explanation")

    with gr.Row():
        conf_thresh = gr.Slider(0, 1, value=0.7, step=0.05, label="Confidence Threshold")
        cam_variant = gr.Radio(["Grad-CAM", "Grad-CAM++"], value="Grad-CAM", label="Grad-CAM Variant")
        mc_samples = gr.Slider(1, 100, value=20, step=1, label="MC Dropout Samples")
        dropout_p = gr.Slider(0.0, 0.9, value=0.1, step=0.05, label="Dropout Probability")

    btn = gr.Button("Explain")

    gc_fig = gr.Plot(label="Grad-CAM / Grad-CAM++")
    gc_txt = gr.Textbox(label="Explanation (Grad-CAM)")
    unc_fig = gr.Plot(label="Uncertainty (MC Dropout)")
    unc_txt = gr.Textbox(label="Explanation (Uncertainty)")

    # Visible control tooltips section (for environments where hovering tooltips are not available)
    gr.Markdown(_TOOLTIPS_MD)

    # ---------- Key interpretability choices (Feynman-style) ----------
    gr.Markdown(_CHOICES_MD)

    # ---------- Further reading / Feynman-style references ----------
    # Short, clickable references so users can read the original papers and deep-dive articles.
    gr.Markdown(_READING_MD)

    def safe_label_lookup(idx):
        """Return the class name for ``idx``, or ``"Class <idx>"`` if it cannot be resolved.

        Guards against ``model.config.id2label`` being missing, not being a
        mapping, or not containing ``idx``.
        """
        try:
            id2label = getattr(model.config, 'id2label', None)
            if id2label is None:
                return f"Class {idx}"
            return id2label.get(int(idx), f"Class {idx}")
        except Exception:
            return f"Class {idx}"

    def run_detect(img, conf_thresh):
        """Detect objects in ``img`` and refresh the detection widgets.

        Args:
            img: PIL image from the upload widget, or None when it was cleared.
            conf_thresh: Minimum score for a detection to be listed.

        Returns:
            Dict keyed by output component: the detection list as text, a
            figure with the boxes drawn, and an update for the dropdown whose
            entries are formatted ``"index: label (score)"``. On failure the
            label shows "Error in detection" and the dropdown is emptied.
        """
        logger.info(f"Running detection with confidence threshold: {conf_thresh}")
        if img is None:
            # The change event also fires when the image is cleared; reset
            # the widgets instead of reporting an error.
            return {det_out: None, det_fig: None, det_choice: gr.update(choices=[], value=None)}
        try:
            inputs = extractor(images=img, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            target_sizes = [img.size[::-1]]
            results = extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
            keep = results["scores"] > conf_thresh
            # Convert to plain Python numbers once: the tensors may live on
            # the GPU, and matplotlib / string formatting need host-side values.
            boxes = results["boxes"][keep].tolist()
            label_ids = results["labels"][keep].tolist()
            scores = results["scores"][keep].tolist()

            logger.info(f"Detection found {len(label_ids)} objects above threshold")

            names = [safe_label_lookup(label_id) for label_id in label_ids]
            det_list = [f"{i}: {name} ({score:.2f})" for i, (name, score) in enumerate(zip(names, scores))]

            fig, ax = plt.subplots()
            ax.imshow(img)
            ax.axis("off")
            for (xmin, ymin, xmax, ymax), name, score in zip(boxes, names, scores):
                ax.add_patch(patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                               fill=False, color="red", lw=2))
                ax.text(xmin, ymin, f"{name}:{score:.2f}", color="black",
                        bbox=dict(facecolor="yellow", alpha=0.5))
            plt.close(fig)  # release pyplot's reference; the figure object is still returned

            default_val = det_list[0] if det_list else None
            logger.debug("Detection visualization created")
            return {det_out: str(det_list), det_fig: fig, det_choice: gr.update(choices=det_list, value=default_val)}
        except Exception as e:
            logger.error(f"Error in run_detect: {str(e)}", exc_info=True)
            return {det_out: "Error in detection", det_fig: None, det_choice: gr.update(choices=[], value=None)}

    detect_outputs = [det_out, det_fig, det_choice]
    img_in.change(run_detect, inputs=[img_in, conf_thresh], outputs=detect_outputs)
    # Re-run detection when the threshold moves so the dropdown never lists
    # detections that the explain step would filter differently.
    conf_thresh.change(run_detect, inputs=[img_in, conf_thresh], outputs=detect_outputs)
    btn.click(interpret, inputs=[img_in, det_choice, conf_thresh, cam_variant, mc_samples, dropout_p],
              outputs=[gc_fig, gc_txt, unc_fig, unc_txt])

logger.info("Gradio interface configured, launching demo")
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ matplotlib
4
+ captum
5
+ ipython
6
+ transformers
7
+ pillow
8
+ lime
9
+ numpy
10
+ scikit-image
11
+ timm
12
+ streamlit
13
+ gradio
14
+ accelerate