Spaces:
Sleeping
Sleeping
'''
This is a combined script that implements DETR object detection with interpretability methods
using Grad-CAM, Grad-CAM++, Integrated Gradients, and Monte Carlo Dropout for uncertainty estimation.
It provides a Gradio-based web interface for users to upload images, select detected objects,
and visualize explanations and uncertainty maps. The web interface currently exposes Grad-CAM /
Grad-CAM++ and MC Dropout; Integrated Gradients is available through the `integrated_grad` function.
How to run it:
```bash
python detr_and_interp.py
```
'''
| import torch, requests, numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| from PIL import Image, ImageFilter | |
| import gradio as gr | |
| from transformers import DetrImageProcessor, DetrForObjectDetection | |
| from torchvision.transforms.functional import resize | |
| from captum.attr import IntegratedGradients | |
| import torch.nn.functional as F | |
| import logging | |
| import os | |
| from datetime import datetime | |
# ---------- Logging Setup ----------
# Every process start gets its own timestamped log file in a "logs" folder next to this
# script; the same records are also echoed to the console.
log_dir = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"detr_interp_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
logger.info("Starting DETR Interpretability Dashboard")

# ---------- Device & model ----------
# Use the GPU whenever torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
logger.info(f"Using device: {device}")

model_name = "facebook/detr-resnet-50"
logger.info(f"Loading model: {model_name}")
# The detector is shared by every function below; the processor turns PIL images into
# the normalized `pixel_values` tensors the detector expects.
model = DetrForObjectDetection.from_pretrained(model_name)
model = model.to(device)
extractor = DetrImageProcessor.from_pretrained(model_name)
model.eval()
logger.info("Model loaded and set to evaluation mode")
| # ---------- Grad-CAM / Grad-CAM++ ---------- | |
def gradcam(img, det_idx, keep, pixel_values, use_pp=False):
    """
    Compute Grad-CAM (or Grad-CAM++) heatmap for a selected detection.

    What it computes:
    - Captures feature-map activations from a late conv layer and the gradients of the
      detection's class logit w.r.t. those activations. Channel-wise weights are computed
      from the gradients and used to combine feature maps into a spatial heatmap.
    Why this matters:
    - Highlights which spatial regions the model used to make the prediction. Useful to
      check whether the detector is attending to the object vs irrelevant background.
    How to interpret results:
    - High values in the returned heatmap indicate regions that contributed positively to
      the detection score. Grad-CAM++ (use_pp=True) uses per-position weights alpha and only
      positive gradients, which often yields sharper, better-localized maps when multiple
      instances overlap.
    Caveats & tips:
    - Choosing a layer too early will give fine-grained but semantically weak maps; too late
      will be coarse. We pick a late backbone conv block (layer4[-1]) as a sensible default.
    - Hooks are removed in a ``finally`` clause, so they never outlive the call, even on error.

    Args:
        img: PIL image the detections were computed on; only ``img.size`` is used, to
            resize the heatmap back to image resolution.
        det_idx: index of the detection among the queries selected by ``keep``.
        keep: boolean mask over DETR queries. NOTE(review): assumes one entry per query in
            query order, which holds when post-processing uses ``threshold=0.0`` as
            ``interpret`` does.
        pixel_values: processed image tensor passed to ``model``; gradients flow through it.
        use_pp: use Grad-CAM++ weighting instead of vanilla Grad-CAM.

    Returns:
        numpy array of shape (img.height, img.width) with values roughly in [0, 1].

    Raises:
        Any error is logged and re-raised.

    References:
    - Selvaraju et al., Grad-CAM (2017): https://arxiv.org/abs/1610.02391
    - Chattopadhay et al., Grad-CAM++ (2018): https://arxiv.org/abs/1710.11063
    """
    method = "Grad-CAM++" if use_pp else "Grad-CAM"
    logger.info(f"Running {method} for detection {det_idx}")
    handles = []
    try:
        # pick a late conv layer that still retains spatial info
        conv_layer = model.model.backbone.conv_encoder.model.layer4[-1]
        activations, gradients = {}, {}

        def fwd(module, inputs, output):
            activations["v"] = output.detach()

        def bwd(module, grad_input, grad_output):
            gradients["v"] = grad_output[0].detach()

        # Hooks must be registered before the forward pass for both of them to fire.
        handles.append(conv_layer.register_forward_hook(fwd))
        if hasattr(conv_layer, "register_full_backward_hook"):
            handles.append(conv_layer.register_full_backward_hook(bwd))
        else:  # older PyTorch releases
            handles.append(conv_layer.register_backward_hook(bwd))
        logger.debug("Hooks registered for Grad-CAM")

        logits = model(pixel_values).logits
        query = int(keep.nonzero()[det_idx])
        # The last logit column is DETR's "no object" class. Exclude it so the explained
        # class matches the label post_process_object_detection reports.
        label_id = int(logits[0, query, :-1].argmax())
        score = logits[0, query, label_id]
        logger.debug(f"Target label_id: {label_id}, score: {score.item():.4f}")
        model.zero_grad()
        score.backward()

        if "v" not in activations or "v" not in gradients:
            raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")
        acts = activations["v"].squeeze(0)  # (channels, h, w)
        grads = gradients["v"].squeeze(0)
        logger.debug(f"Activations shape: {acts.shape}, Gradients shape: {grads.shape}")

        if use_pp:
            # Grad-CAM++ closed form:
            # alpha = g^2 / (2 g^2 + sum(A) g^3), weight_k = sum_ij alpha_ij * relu(g_ij)
            grads_2 = grads.pow(2)
            grads_3 = grads_2 * grads
            acts_sum = acts.sum(dim=(1, 2), keepdim=True)
            alpha = grads_2 / (2 * grads_2 + acts_sum * grads_3 + 1e-8)
            alpha = torch.where(grads != 0, alpha, torch.zeros_like(alpha))
            weights = (alpha * torch.relu(grads)).sum(dim=(1, 2))
        else:  # vanilla Grad-CAM: global-average-pooled gradients
            weights = grads.mean(dim=(1, 2))

        cam = torch.relu((weights[:, None, None] * acts).sum(0))
        cam = cam / (cam.max() + 1e-8)
        cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].cpu().numpy()
        logger.info(f"{method} completed successfully")
        return cam_resized
    except Exception as e:
        logger.error(f"Error in gradcam: {str(e)}", exc_info=True)
        raise
    finally:
        for handle in handles:
            handle.remove()
        # Release parameter gradients left over from the backward pass.
        model.zero_grad(set_to_none=True)
| # ---------- Integrated Gradients ---------- | |
def integrated_grad(img, det_idx, keep, outputs_for_attr, pixel_values, baseline="black",
                    n_steps=25, internal_batch_size=None):
    """
    Compute Integrated Gradients attribution map for a detection's logit.

    What it computes:
    - Integrates gradients along a straight path from a baseline input to the real input in
      the processor's normalized pixel space. This produces per-pixel, per-channel
      attributions, which are averaged over channels here.
    Why baseline choice matters:
    - The baseline defines what the model should consider as 'no signal'. Supported values:
        "black": an all-black image pushed through ``extractor``. This is the default, and
            any unknown value also falls back to it.
        "blur":  a heavily blurred copy of ``img`` pushed through ``extractor``.
        "mean":  zeros in normalized space, i.e. the processor's normalization mean colour.
            NOTE(review): assumes ``extractor`` normalizes with a mean/std.
      Raw zeros are NOT black once the input is normalized, which is why "black" and "blur"
      are both built with ``extractor``.
    How to read the output:
    - Before normalization, values > 0 increase the detection logit vs the baseline and values
      < 0 reduce it. The result is min-max normalized to [0, 1] for display, so 0 no longer
      marks the sign boundary.

    Args:
        img: PIL image the detections were computed on.
        det_idx: index of the detection among the queries selected by ``keep``.
        keep: boolean mask over DETR queries (one entry per query, in query order).
        outputs_for_attr: model outputs for ``pixel_values``; only ``.logits`` is read, to pick
            the target class.
        pixel_values: processed image tensor to attribute.
        baseline: "black", "blur" or "mean" (see above).
        n_steps: number of integration steps. More steps are smoother but cost more.
        internal_batch_size: forwarded to Captum to bound memory. None means all steps in
            one batch.

    Returns:
        numpy array with the spatial size of ``pixel_values`` (not resized to ``img``),
        values in [0, 1].

    Raises:
        ValueError if the baseline tensor shape differs from ``pixel_values``. Any other error
        is logged and re-raised.

    References:
    - Distill article on baselines: https://distill.pub/2020/attribution-baselines
    - Captum IntegratedGradients docs: https://captum.ai/api/integrated_gradients.html
    """
    logger.info(f"Running Integrated Gradients with {baseline} baseline for detection {det_idx}")
    try:
        query = int(keep.nonzero()[det_idx])
        # Exclude DETR's trailing "no object" logit so the target matches the reported label.
        label_id = int(outputs_for_attr.logits[0, query, :-1].argmax())
        logger.debug(f"IG target label_id: {label_id}")

        # Baselines
        if baseline == "mean":
            base = torch.zeros_like(pixel_values)
            logger.debug("Using mean (normalized zeros) baseline")
        else:
            if baseline == "blur":
                reference = img.filter(ImageFilter.GaussianBlur(radius=15))
                logger.debug("Using blurred baseline")
            else:
                reference = Image.new("RGB", img.size)  # all-black image, same size as img
                if baseline == "black":
                    logger.debug("Using black baseline")
                else:
                    logger.debug("Defaulting to black baseline")
            base = extractor(images=reference, return_tensors="pt")["pixel_values"].to(device)
        if base.shape != pixel_values.shape:
            raise ValueError(
                f"Baseline shape {tuple(base.shape)} does not match input shape {tuple(pixel_values.shape)}"
            )

        def forward_func(pix):
            # One scalar (the target class logit of the selected query) per batch element.
            return model(pix).logits[:, query, label_id]

        ig = IntegratedGradients(forward_func)
        attr, delta = ig.attribute(
            pixel_values,
            baselines=base,
            n_steps=n_steps,
            internal_batch_size=internal_batch_size,
            return_convergence_delta=True,
        )
        # Completeness check: a large delta suggests too few integration steps.
        logger.debug(f"IG convergence delta (max abs): {delta.abs().max().item():.4e}")
        arr = attr.squeeze(0).mean(0).detach().cpu().numpy()
        logger.info(f"Integrated Gradients with {baseline} baseline completed")
        return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
    except Exception as e:
        logger.error(f"Error in integrated_grad: {str(e)}", exc_info=True)
        raise
| # ---------- Monte Carlo Dropout Uncertainty ---------- | |
def mc_dropout_uncertainty(img, det_idx, keep, pixel_values, n_samples=20, dropout_p=0.1):
    """
    Estimate explanation uncertainty by running stochastic forward passes with dropout active.

    What it computes:
    - Picks the target class once from a deterministic (eval-mode) pass, so every sample
      explains the same class.
    - Runs the model ``n_samples`` times in train mode, which enables dropout, and computes a
      Grad-CAM per run.
    - Returns the per-pixel mean and standard deviation across CAMs. High std indicates the
      model's focus is unstable across stochastic perturbations.
    Why this helps:
    - If heatmaps vary a lot, the interpretability output is less reliable. Use this to flag
      detections where explanations may not be trustworthy.
    How ``dropout_p`` is applied:
    - It temporarily overrides ``p`` on ``torch.nn.Dropout`` modules and any numeric
      ``dropout`` attribute on other modules. NOTE(review): Hugging Face transformer layers
      typically store their rate this way and apply it functionally; confirm for the
      installed transformers version.
    - All overridden rates, the hooks, and the model's train/eval mode are restored before
      returning, even on error.
    Practical tips:
    - Increasing n_samples reduces variance in the estimate but increases runtime.

    Args:
        img: PIL image the detections were computed on; ``img.size`` sets the output size.
        det_idx: index of the detection among the queries selected by ``keep``.
        keep: boolean mask over DETR queries (one entry per query, in query order).
        pixel_values: processed image tensor passed to ``model``.
        n_samples: number of stochastic passes.
        dropout_p: dropout rate used during sampling.

    Returns:
        (mean_map, std_map): numpy arrays of shape (img.height, img.width), each min-max
        normalized to [0, 1]. The std map therefore shows relative, not absolute,
        uncertainty. Both are zeros when ``n_samples`` < 1.

    Raises:
        Any error is logged and re-raised.
    """
    logger.info(f"Running MC Dropout uncertainty: samples={n_samples}, p={dropout_p}, detection={det_idx}")
    was_training = model.training
    handles = []
    overridden = []  # (module, attribute name, original value) restored in `finally`
    try:
        out_shape = (img.size[1], img.size[0])
        conv_layer = model.model.backbone.conv_encoder.model.layer4[-1]
        query = int(keep.nonzero()[det_idx])

        # Deterministic pass to fix the explained class. The trailing "no object" column is
        # excluded so the class matches the label shown in the UI.
        model.eval()
        with torch.no_grad():
            ref_logits = model(pixel_values).logits
        label_id = int(ref_logits[0, query, :-1].argmax())

        # Enable stochastic layers at the requested rate.
        model.train()
        rate = float(dropout_p)
        for module in model.modules():
            if isinstance(module, torch.nn.Dropout):
                overridden.append((module, "p", module.p))
                module.p = rate
            else:
                current = getattr(module, "dropout", None)
                if isinstance(current, (int, float)) and not isinstance(current, bool):
                    overridden.append((module, "dropout", current))
                    module.dropout = rate

        activations, gradients = {}, {}

        def fwd(module, inputs, output):
            activations["v"] = output.detach()

        def bwd(module, grad_input, grad_output):
            gradients["v"] = grad_output[0].detach()

        # Register once, BEFORE any forward pass. Forward hooks only fire for passes started
        # after registration, and full backward hooks are wired up during the forward pass.
        handles.append(conv_layer.register_forward_hook(fwd))
        if hasattr(conv_layer, "register_full_backward_hook"):
            handles.append(conv_layer.register_full_backward_hook(bwd))
        else:  # older PyTorch releases
            handles.append(conv_layer.register_backward_hook(bwd))

        cams = []
        for sample in range(int(n_samples)):
            activations.clear()
            gradients.clear()
            score = model(pixel_values).logits[0, query, label_id]
            model.zero_grad()
            score.backward()
            if "v" not in activations or "v" not in gradients:
                logger.warning(f"No activations captured in sample {sample}, using fallback zero map")
                cams.append(np.zeros(out_shape))
                continue
            act = activations["v"].squeeze(0)
            grad = gradients["v"].squeeze(0)
            weights = grad.mean(dim=(1, 2))
            cam = torch.relu((weights[:, None, None] * act).sum(0))
            cam = cam / (cam.max() + 1e-8)
            cams.append(resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].cpu().numpy())

        if not cams:
            logger.error("No valid CAM maps generated")
            return np.zeros(out_shape), np.zeros(out_shape)
        cams_arr = np.stack(cams, axis=0)
        mean_map = cams_arr.mean(0)
        std_map = cams_arr.std(0)
        mean_map = (mean_map - mean_map.min()) / (mean_map.max() - mean_map.min() + 1e-8)
        std_map = (std_map - std_map.min()) / (std_map.max() - std_map.min() + 1e-8)
        logger.info("MC Dropout uncertainty completed")
        return mean_map, std_map
    except Exception as e:
        logger.error(f"Error in mc_dropout_uncertainty: {str(e)}", exc_info=True)
        raise
    finally:
        for handle in handles:
            handle.remove()
        for module, attr_name, value in overridden:
            setattr(module, attr_name, value)
        model.train(was_training)
        model.zero_grad(set_to_none=True)
| # ---------- Full pipeline ---------- | |
def interpret(img, det_choice, conf_thresh, cam_variant, mc_samples, dropout_p):
    """
    Gradio callback: explain one detection with Grad-CAM(/++) and MC-Dropout uncertainty.

    Args:
        img: uploaded PIL image, or None when nothing is uploaded.
        det_choice: dropdown value formatted as "index: label (score)", or None (uses 0).
        conf_thresh: score threshold. It must match the one used to build the dropdown,
            otherwise indices refer to different detections.
        cam_variant: "Grad-CAM" or "Grad-CAM++".
        mc_samples: number of MC Dropout passes.
        dropout_p: dropout rate for MC Dropout.

    Returns:
        (fig1, exp1, fig2, exp2): the Grad-CAM overlay figure and its explanation text, then
        the mean/std uncertainty figure and its explanation text. On failure both figures are
        None and ``exp1`` carries the message. Errors are reported, never raised.
    """
    logger.info(f"Starting interpretation - detection: {det_choice}, threshold: {conf_thresh}, cam: {cam_variant}, mc_samples: {mc_samples}, dropout_p: {dropout_p}")
    if img is None:
        logger.warning("No image provided")
        return None, "Please upload an image first", None, ""
    try:
        inputs = extractor(images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Separate leaf tensor with autograd enabled for the gradient-based explainers.
        pixel_values_attr = inputs["pixel_values"].clone().requires_grad_(True)
        target_sizes = [img.size[::-1]]
        # threshold=0.0 keeps (effectively) every query, so `keep` stays aligned with DETR's
        # query order, which the explainers use to locate the detection's logits.
        results = extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
        keep = results["scores"] > conf_thresh
        labels = results["labels"][keep]
        logger.info(f"Found {len(labels)} detections above threshold {conf_thresh}")
        if len(labels) == 0:
            logger.warning("No detections found above threshold")
            return None, "No detections above threshold", None, ""

        det_idx = 0
        if det_choice is not None:
            try:
                det_idx = int(str(det_choice).split(":")[0])
            except ValueError:
                det_idx = 0
        if not 0 <= det_idx < len(labels):
            logger.warning(f"Detection index {det_idx} out of range for {len(labels)} detections")
            return (None,
                    f"Detection {det_idx} is not available at threshold {conf_thresh}; "
                    "re-run detection and pick again.",
                    None, "")
        class_id = int(labels[det_idx])
        label = model.config.id2label.get(class_id, f"Class {class_id}")
        logger.info(f"Selected detection {det_idx}: {label}")

        # Grad-CAM / Grad-CAM++ (single deterministic pass)
        cam = gradcam(img, det_idx, keep, pixel_values_attr, use_pp=(cam_variant == "Grad-CAM++"))
        fig1, ax1 = plt.subplots()
        ax1.imshow(img)
        ax1.imshow(cam, cmap="jet", alpha=0.5)
        ax1.axis("off")
        ax1.set_title(f"{cam_variant}: {label}")
        plt.close(fig1)  # drop from pyplot's registry; the Figure object is returned to Gradio
        logger.debug(f"{cam_variant} visualization created")

        # MC Dropout Uncertainty analysis
        mean_map, std_map = mc_dropout_uncertainty(
            img, det_idx, keep, pixel_values_attr,
            n_samples=int(mc_samples), dropout_p=float(dropout_p),
        )
        # Composite figure: mean map and std map side-by-side
        fig2, axes = plt.subplots(1, 2, figsize=(8, 4))
        axes[0].imshow(img)
        axes[0].imshow(mean_map, cmap='hot', alpha=0.5)
        axes[0].axis('off')
        axes[0].set_title('Predictive Mean')
        axes[1].imshow(img)
        axes[1].imshow(std_map, cmap='viridis', alpha=0.5)
        axes[1].axis('off')
        axes[1].set_title('Predictive Std (Uncertainty)')
        plt.close(fig2)
        logger.debug("MC Dropout uncertainty visualization created")

        exp1 = f"{cam_variant}:\nGradient-weighted feature maps — highlights where DETR focused."
        exp2 = f"MC Dropout Uncertainty:\nSamples={mc_samples}, dropout={dropout_p}. Shows predictive mean and per-pixel std as uncertainty."
        logger.info("Interpretation completed successfully")
        return fig1, exp1, fig2, exp2
    except Exception as e:
        logger.error(f"Error in interpret function: {str(e)}", exc_info=True)
        return None, f"Error: {str(e)}", None, ""
| # ---------- Gradio UI ---------- | |
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 DETR Interpretability Dashboard with Controls")
    gr.Markdown(
        """
**How to use this dashboard**
- Upload an image using the left panel. The model will run object detection and list detected objects. Need a test image? Try [ImageNet](https://www.image-net.org/).
- Use the "Confidence Threshold" slider to filter detections by score. Detections below the threshold are hidden, and the detection list is refreshed when the threshold changes.
- Pick a detection from the dropdown to generate explanations for that object.
- Choose between `Grad-CAM` and `Grad-CAM++` (Grad-CAM++ often gives sharper, more localized maps).
- `MC Dropout Samples` controls how many stochastic forward passes are used to estimate prediction uncertainty. More samples give smoother estimates but take longer.
- `Dropout Probability` sets the dropout rate used during MC Dropout; higher values typically increase predicted uncertainty.

A quick-reference list describing every control is shown below the outputs.
"""
    )
    with gr.Row():
        img_in = gr.Image(type="pil", label="Upload an image")
        det_out = gr.Label(label="Detections")
    det_fig = gr.Plot(label="Detections visualization")
    det_choice = gr.Dropdown(label="Pick a detection for explanation")
    with gr.Row():
        conf_thresh = gr.Slider(0, 1, value=0.7, step=0.05, label="Confidence Threshold")
        cam_variant = gr.Radio(["Grad-CAM", "Grad-CAM++"], value="Grad-CAM", label="Grad-CAM Variant")
        mc_samples = gr.Slider(1, 100, value=20, step=1, label="MC Dropout Samples")
        dropout_p = gr.Slider(0.0, 0.9, value=0.1, step=0.05, label="Dropout Probability")
    btn = gr.Button("Explain")
    gc_fig = gr.Plot(label="Grad-CAM / Grad-CAM++")
    gc_txt = gr.Textbox(label="Explanation (Grad-CAM)")
    unc_fig = gr.Plot(label="Uncertainty (MC Dropout)")
    unc_txt = gr.Textbox(label="Explanation (Uncertainty)")

    # Visible one-line description of every control.
    gr.Markdown(
        """
**Control tooltips (quick reference)**
- Confidence Threshold: Filter out detections with confidence below this value.
- Grad-CAM Variant: Choose the gradient-based visualization method. Grad-CAM++ may highlight smaller regions more precisely.
- MC Dropout Samples: Number of stochastic forward passes for uncertainty estimation. Increase for more stable results.
- Dropout Probability: Dropout rate used during MC Dropout sampling. Higher values typically increase predictive variance.
- Pick a detection: Select which detected object to explain. Format shown as 'index: label (score)'.
"""
    )
    # ---------- Key interpretability choices (Feynman-style) ----------
    gr.Markdown(
        """
**Key interpretability choices & why they matter**
- **Baseline (Integrated Gradients)**: Defines what 'no signal' looks like. A black image is simple, but blurred or neutral baselines may give more meaningful attributions. (Integrated Gradients is implemented in the code but not yet exposed in this dashboard.)
- **Which conv layer for Grad-CAM**: Early layers give fine texture but low semantics; very late layers are coarse. A late backbone conv (default used) is a good compromise.
- **Number of MC Dropout samples**: More samples = smoother, more stable uncertainty estimates, but higher compute cost.
- **Grad-CAM vs Grad-CAM++**: Grad-CAM++ can be sharper and better for overlapping instances; vanilla Grad-CAM is faster and simpler.
"""
    )
    # ---------- Further reading / Feynman-style references ----------
    # Short, clickable references to the original papers and deep-dive articles.
    gr.Markdown(
        """
**Further reading (recommended)**
- [Grad-CAM — Selvaraju et al., 2017 (arXiv)](https://arxiv.org/abs/1610.02391) — the original Grad-CAM paper; explains the core idea of gradient-weighted localization.
- [Grad-CAM++ — Chattopadhay et al.](https://arxiv.org/abs/1710.11063) — an improved variant that often produces sharper maps and handles multiple instances better.
- [Visualizing the Impact of Feature Attribution Baselines (Distill)](https://distill.pub/2020/attribution-baselines) — an accessible deep dive on baseline choices for Integrated Gradients.
- [Captum docs — IntegratedGradients](https://captum.ai/api/integrated_gradients.html) — practical API notes for baseline, n_steps, and convergence delta.
- [Constructing sensible baselines for Integrated Gradients](https://arxiv.org/abs/2004.09627) — discussion and techniques for choosing baselines beyond a black image.
- [A New Baseline Assumption of Integrated Gradients Based on Shapley Values](https://arxiv.org/html/2310.04821v3) — recent research on improved baselines.
"""
    )

    # Helper: safe label getter in case model.config.id2label is missing or not a dict
    def safe_label_lookup(idx):
        """Return the class name for ``idx``, or "Class <idx>" when it cannot be resolved."""
        try:
            id2label = getattr(model.config, 'id2label', None)
            if id2label is None:
                return f"Class {idx}"
            return id2label.get(int(idx), f"Class {idx}")
        except Exception:
            return f"Class {idx}"

    def run_detect(img, conf_thresh):
        """Detect objects, draw their boxes, and refresh the detection dropdown.

        Returns a dict keyed by the output components, as Gradio expects when the
        outputs are given as a list of components.
        """
        logger.info(f"Running detection with confidence threshold: {conf_thresh}")
        if img is None:
            # Image cleared: reset the outputs instead of failing inside the processor.
            return {det_out: None, det_fig: None, det_choice: gr.update(choices=[], value=None)}
        try:
            inputs = extractor(images=img, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            target_sizes = [img.size[::-1]]
            results = extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
            keep = results["scores"] > conf_thresh
            boxes, labels, scores = results["boxes"][keep], results["labels"][keep], results["scores"][keep]
            logger.info(f"Detection found {len(labels)} objects above threshold")
            # The "index: ..." prefix is parsed back by `interpret` to select the detection.
            det_list = [f"{i}: {safe_label_lookup(l.item())} ({s:.2f})" for i, (l, s) in enumerate(zip(labels, scores))]
            fig, ax = plt.subplots()
            ax.imshow(img)
            ax.axis("off")
            for box, label, score in zip(boxes, labels, scores):
                xmin, ymin, xmax, ymax = box
                ax.add_patch(patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color="red", lw=2))
                ax.text(xmin, ymin, f"{safe_label_lookup(label.item())}:{score:.2f}", color="black",
                        bbox=dict(facecolor="yellow", alpha=0.5))
            plt.close(fig)
            default_val = det_list[0] if det_list else None
            logger.debug("Detection visualization created")
            return {det_out: str(det_list), det_fig: fig, det_choice: gr.update(choices=det_list, value=default_val)}
        except Exception as e:
            logger.error(f"Error in run_detect: {str(e)}", exc_info=True)
            return {det_out: "Error in detection", det_fig: None, det_choice: gr.update(choices=[], value=None)}

    img_in.change(run_detect, inputs=[img_in, conf_thresh], outputs=[det_out, det_fig, det_choice])
    # Re-run detection when the threshold moves, so dropdown indices stay aligned with the
    # threshold `interpret` filters with. Prefer `release` (one run per drag) when this
    # Gradio version provides it; otherwise fall back to `change`.
    threshold_event = getattr(conf_thresh, "release", None) or conf_thresh.change
    threshold_event(run_detect, inputs=[img_in, conf_thresh], outputs=[det_out, det_fig, det_choice])
    btn.click(interpret, inputs=[img_in, det_choice, conf_thresh, cam_variant, mc_samples, dropout_p],
              outputs=[gc_fig, gc_txt, unc_fig, unc_txt])

if __name__ == "__main__":
    logger.info("Gradio interface configured, launching demo")
    demo.launch()