File size: 5,466 Bytes
ff7112c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
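'''
Usage of DETR with Captum for interpretability.

Demonstrates Grad-CAM (implemented with PyTorch forward/backward hooks) and
Integrated Gradients (via Captum) on an object-detection model.

Runs DETR on a fixed COCO val2017 image, picks one confident detection, and
visualizes which image regions drive that detection's class logit.
Intended for developers and ML practitioners interested in model interpretability.

'''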

import torch, requests, numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
from torchvision.transforms.functional import resize
from captum.attr import IntegratedGradients

# ---------------- 1. Load DETR ----------------
model_name = "facebook/detr-resnet-50"
# nn.Module.eval() returns the module itself, so the model is switched to
# inference mode in the same expression that loads it.
model = DetrForObjectDetection.from_pretrained(model_name).eval()
# Preprocessing / post-processing helper matching the checkpoint above.
feature_extractor = DetrImageProcessor.from_pretrained(model_name)

# ---------------- 2. Load an image ----------------
# COCO val2017 image used in the Hugging Face DETR examples (cats, remotes, couch).
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# The timeout avoids hanging forever on a stalled connection. raise_for_status()
# surfaces HTTP errors directly instead of as an opaque PIL decode failure.
# The context manager closes the streamed connection afterwards.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # convert() forces PIL to decode now, while the stream is still open.
    img = Image.open(response.raw).convert("RGB")

# ---------------- 3. Preprocess & forward ----------------
inputs = feature_extractor(images=img, return_tensors="pt")
pixel_values = inputs["pixel_values"]
# This pass only feeds post-processing, so don't build an autograd graph for it.
# The gradient-carrying passes for Grad-CAM / IG are run separately later.
with torch.no_grad():
    outputs = model(pixel_values)

# PIL reports (width, height); reversed here to (height, width).
target_sizes = torch.tensor([img.size[::-1]])
# threshold=0.0 is intended to keep every query, in order, so that a position in
# `results` doubles as the DETR query index (the Grad-CAM step maps kept
# detections back to query indices on this assumption). Confidence filtering
# happens in the next step instead.
results = feature_extractor.post_process_object_detection(
    outputs, target_sizes=target_sizes, threshold=0.0
)[0]

# ---------------- 4. Pick detection ----------------
SCORE_THRESHOLD = 0.7  # minimum confidence for a detection to be considered

# `keep` is a boolean mask over `results`; later steps use it to recover query indices.
keep = results["scores"] > SCORE_THRESHOLD
boxes, labels, scores = results["boxes"][keep], results["labels"][keep], results["scores"][keep]
if len(scores) == 0:
    raise RuntimeError(
        f"No detection scored above {SCORE_THRESHOLD}; lower the threshold or use another image"
    )

chosen_idx = 0  # first confident detection, in `results` order
chosen_label = labels[chosen_idx].item()
chosen_name = model.config.id2label[chosen_label]
score_val = float(scores[chosen_idx])
print(f"Chosen detection: {chosen_name}, score={score_val:.2f}")

# ---------------- 5. Grad-CAM ----------------
# Grad-CAM target layer: the last Conv2d (in named_modules order) inside the
# backbone, falling back to the last Conv2d anywhere in the model.
def _last_conv2d(root):
    """Return ``(qualified_name, module)`` of the final ``Conv2d`` under *root*, or ``None``."""
    convs = [
        (qualname, sub)
        for qualname, sub in root.named_modules()
        if isinstance(sub, torch.nn.Conv2d)
    ]
    return convs[-1] if convs else None


backbone = getattr(model.model, "backbone", None)
found = None if backbone is None else _last_conv2d(backbone)
if found is None:
    found = _last_conv2d(model)
if found is None:
    raise RuntimeError("No Conv2d layer found for Grad-CAM")
conv_name, conv_layer = found

# Storage filled by the hooks below.
activations, gradients = {}, {}


def forward_hook(m, i, o):
    """Forward hook: stash a detached copy of the layer's output activations."""
    activations["value"] = o.detach()


def _save_gradient(m, grad_input, grad_output):
    """Backward hook: stash a detached copy of the gradient w.r.t. the layer's output."""
    gradients["value"] = grad_output[0].detach()


# Keep the handles so the hooks can be removed once Grad-CAM has its data.
# Otherwise they keep firing (and copying activations) on every later forward pass.
hook_handles = [conv_layer.register_forward_hook(forward_hook)]
if hasattr(conv_layer, "register_full_backward_hook"):
    # Preferred API; register_backward_hook is the older, deprecated variant.
    hook_handles.append(conv_layer.register_full_backward_hook(_save_gradient))
else:
    hook_handles.append(conv_layer.register_backward_hook(_save_gradient))

# `keep` masks the post-processed detections, which were produced with
# threshold=0.0 (one entry per query, in order). Its True positions are
# therefore the query indices of the kept detections.
keep_idxs = torch.nonzero(keep).flatten()
chosen_query_idx = int(keep_idxs[chosen_idx])

# The earlier forward ran before the hooks existed. Re-run it with an input
# that requires grad, then backprop from the chosen query/class logit.
try:
    pixel_values_for_grad = pixel_values.clone().detach().requires_grad_(True)
    outputs_for_grad = model(pixel_values_for_grad)
    score_for_grad = outputs_for_grad.logits[0, chosen_query_idx, chosen_label]
    model.zero_grad()
    score_for_grad.backward()
finally:
    for handle in hook_handles:
        handle.remove()

if "value" not in activations or "value" not in gradients:
    raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")

acts = activations["value"].squeeze(0)  # (C, H, W)
grads = gradients["value"].squeeze(0)
# Channel weights = spatially averaged gradients (the Grad-CAM weighting).
weights = grads.mean(dim=(1, 2))
cam = torch.relu((weights[:, None, None] * acts).sum(0))
# clamp_min avoids 0/0 = NaN when every weighted channel sum is non-positive.
# It leaves any map with a positive maximum unchanged.
cam = cam / cam.max().clamp_min(1e-8)
cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].numpy()

# ---------------- 6. Integrated Gradients ----------------
def forward_func(pixel_values):
    """Return the chosen query/class logit for each image in the batch, shape (batch,).

    Captum calls this on batches of interpolated inputs. Returning one scalar
    per sample means ``ig.attribute`` needs no ``target`` argument.
    """
    out = model(pixel_values=pixel_values)
    return out.logits[:, chosen_query_idx, chosen_label]


ig = IntegratedGradients(forward_func)
# internal_batch_size bounds how many interpolated full-resolution images go
# through DETR at once. Without it, all n_steps are evaluated in a single batch.
attributions, delta = ig.attribute(
    pixel_values, n_steps=25, internal_batch_size=5, return_convergence_delta=True
)
# A large |delta| means the n_steps approximation of the path integral is poor.
print(f"Integrated Gradients convergence delta: {float(delta.abs().max()):.4f}")

# Per-pixel attribution magnitude: mean |attribution| over the RGB channels.
# With a signed mean, zero-attribution pixels would land mid-range after
# min-max scaling and tint the whole overlay.
attr_map = attributions[0].detach().abs().mean(0)
# The image processor may resize the input, so resample the map to the
# original image's (height, width); otherwise it won't line up in the overlay.
attr = resize(attr_map[None, None], img.size[::-1])[0, 0].cpu().numpy()
attr = (attr - attr.min()) / (attr.max() - attr.min() + 1e-8)

# ---------------- 7. Visualize ----------------
# Three side-by-side panels: the plain image, then the image with each
# attribution map blended on top. Each entry is (title, overlay, colormap).
fig, axs = plt.subplots(1, 3, figsize=(16, 6))
panels = (
    (f"Original: {chosen_name}", None, None),
    ("Grad-CAM heatmap", cam_resized, "jet"),
    ("Integrated Gradients", attr, "hot"),
)
for ax, (title, overlay, cmap) in zip(axs, panels):
    ax.imshow(img)
    if overlay is not None:
        ax.imshow(overlay, cmap=cmap, alpha=0.5)
    ax.set_title(title)
    ax.axis("off")
plt.show()