import torch
import torch.nn.functional as F

class SpectrogramGradCAM:
    """Minimal Grad-CAM for conv nets ending with global pooling + linear."""
    def __init__(self, model, target_layer_name: str = "features.6"):
        self.model = model
        self.gradients = None
        self.activations = None

        # Get the target conv layer by name (e.g., "features.6")
        target_layer = dict(model.named_modules())[target_layer_name]

        # Register hooks
        self._fh = target_layer.register_forward_hook(self._forward_hook)
        self._bh = target_layer.register_full_backward_hook(self._backward_hook)

    def _forward_hook(self, module, inp, out):
        # Keep on-device tensors; we’ll only read them
        self.activations = out

    def _backward_hook(self, module, grad_input, grad_output):
        # Gradient w.r.t. the layer output
        self.gradients = grad_output[0]
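
    def remove(self):
        # Convenience cleanup (added helper): detach both hooks so the
        # wrapped layer is left unmodified once we're done computing CAMs.
        self._fh.remove()
        self._bh.remove()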

    def __call__(self, x, class_idx=None):
        """
        x: Tensor [B,1,H,W] with requires_grad=True.
        Returns:
          cam: np.ndarray [B,H,W] in [0,1]
          logits: np.ndarray [B,num_classes]
        """
        assert x.requires_grad, "Input to Grad-CAM must require grad"

        self.model.zero_grad(set_to_none=True)
        logits = self.model(x)  # forward WITH grads

        # Choose the class to explain: an int applies one class to the whole
        # batch, a [B] tensor picks a (possibly different) class per sample
        if class_idx is None:
            class_idx = logits.argmax(dim=1)  # default: predicted class, [B]
        if isinstance(class_idx, int):
            selected = logits[:, class_idx]   # [B]
        else:
            selected = logits.gather(1, class_idx.view(-1, 1)).squeeze(1)  # [B]

        # Backprop from the selected class score (grads were zeroed above)
        selected.sum().backward()

        # Build CAM from gradients & activations: [B,C,H,W]
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)   # GAP over H,W
        cam = (weights * self.activations).sum(dim=1, keepdim=True)  # [B,1,H,W]
        cam = F.relu(cam)

        # Normalize per-sample to [0,1]
        B, _, H, W = cam.shape
        cam = cam.view(B, H, W)
        cam_min = cam.view(B, -1).min(dim=1, keepdim=True)[0].view(B, 1, 1)
        cam_max = cam.view(B, -1).max(dim=1, keepdim=True)[0].view(B, 1, 1)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-6)

        return cam.detach().cpu().numpy(), logits.detach().cpu().numpy()
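
# --- Usage sketch ---
# Illustrative only: TinyCNN below is a hypothetical stand-in for any conv
# net whose named modules expose a conv layer (here "features.6") followed
# by global pooling and a linear head. The CAM comes back at feature-map
# resolution, so we upsample it with F.interpolate before overlaying it on
# the input spectrogram.
if __name__ == "__main__":
    import torch.nn as nn

    class TinyCNN(nn.Module):
        def __init__(self, num_classes: int = 4):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),   # 0, 1
                nn.MaxPool2d(2),                             # 2
                nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),  # 3, 4
                nn.MaxPool2d(2),                             # 5
                nn.Conv2d(32, 64, 3, padding=1),             # 6 <- target
            )
            self.classifier = nn.Linear(64, num_classes)

        def forward(self, x):
            h = self.features(x)
            h = F.adaptive_avg_pool2d(h, 1).flatten(1)  # global average pool
            return self.classifier(h)

    model = TinyCNN().eval()
    gradcam = SpectrogramGradCAM(model, target_layer_name="features.6")

    x = torch.randn(2, 1, 128, 256, requires_grad=True)  # fake log-mel batch
    cam, logits = gradcam(x)  # cam: (2, 32, 64) after two 2x poolings

    # Upsample the CAM to the input resolution for overlay plotting
    cam_full = F.interpolate(
        torch.from_numpy(cam).unsqueeze(1),  # [B, 1, h, w]
        size=x.shape[-2:], mode="bilinear", align_corners=False,
    ).squeeze(1).numpy()  # [B, H, W]

    gradcam.remove()
    print(cam_full.shape, logits.shape)  # (2, 128, 256) (2, 4)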