import torch
import torch.nn.functional as F


class SpectrogramGradCAM:
    """Minimal Grad-CAM for conv nets ending with global pooling + linear."""

    def __init__(self, model, target_layer_name: str = "features.6"):
        self.model = model
        self.gradients = None
        self.activations = None
        # Get the target conv layer by name (e.g., "features.6")
        target_layer = dict(model.named_modules())[target_layer_name]
        # Register hooks to capture activations and their gradients
        self._fh = target_layer.register_forward_hook(self._forward_hook)
        self._bh = target_layer.register_full_backward_hook(self._backward_hook)

    def _forward_hook(self, module, inp, out):
        # Keep on-device tensors; we'll only read them
        self.activations = out

    def _backward_hook(self, module, grad_input, grad_output):
        # Gradient w.r.t. the layer output
        self.gradients = grad_output[0]

    def __call__(self, x, class_idx=None):
        """
        x: Tensor [B, 1, H, W] with requires_grad=True.

        Returns:
            cam: np.ndarray [B, h, w] in [0, 1], at the target layer's spatial resolution
            logits: np.ndarray [B, num_classes]
        """
        assert x.requires_grad, "Input to Grad-CAM must require grad"
        self.model.zero_grad(set_to_none=True)
        logits = self.model(x)  # forward WITH grads

        # Choose class per sample
        if class_idx is None:
            class_idx = logits.argmax(dim=1)  # [B]
        if isinstance(class_idx, int):
            selected = logits[:, class_idx]  # [B]
        else:
            selected = logits.gather(1, class_idx.view(-1, 1)).squeeze(1)  # [B]

        # Backprop from the selected class scores
        selected.sum().backward(retain_graph=True)

        # Build CAM from gradients & activations: [B, C, h, w]
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # GAP over h, w
        cam = (weights * self.activations).sum(dim=1, keepdim=True)  # [B, 1, h, w]
        cam = F.relu(cam)

        # Normalize per-sample to [0, 1]
        B, _, H, W = cam.shape
        cam = cam.view(B, H, W)
        cam_min = cam.view(B, -1).min(dim=1, keepdim=True)[0].view(B, 1, 1)
        cam_max = cam.view(B, -1).max(dim=1, keepdim=True)[0].view(B, 1, 1)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-6)

        return cam.detach().cpu().numpy(), logits.detach().cpu().numpy()
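

# Usage sketch (illustrative only): the CNN below is a hypothetical stand-in with the
# layout this class expects (a "features" conv stack whose index-6 module is the last
# Conv2d, followed by global average pooling and a linear head); it is not the model
# this Space actually serves.
if __name__ == "__main__":
    import torch.nn as nn

    class TinySpectrogramCNN(nn.Module):
        def __init__(self, num_classes: int = 10):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(1, 16, 3, padding=1),   # features.0
                nn.ReLU(),                        # features.1
                nn.MaxPool2d(2),                  # features.2
                nn.Conv2d(16, 32, 3, padding=1),  # features.3
                nn.ReLU(),                        # features.4
                nn.MaxPool2d(2),                  # features.5
                nn.Conv2d(32, 64, 3, padding=1),  # features.6  <- Grad-CAM target
                nn.ReLU(),                        # features.7
            )
            self.classifier = nn.Linear(64, num_classes)

        def forward(self, x):
            feats = self.features(x)
            pooled = feats.mean(dim=(2, 3))  # global average pooling
            return self.classifier(pooled)

    model = TinySpectrogramCNN().eval()
    gradcam = SpectrogramGradCAM(model, target_layer_name="features.6")

    # Fake mel-spectrogram batch [B, 1, n_mels, time]; the input must require grad.
    x = torch.randn(2, 1, 64, 128, requires_grad=True)
    cam, logits = gradcam(x)
    print(cam.shape, logits.shape)  # (2, 16, 32) and (2, 10)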