Spaces:
Running
Running
| import torch | |
| import torch.nn.functional as F | |
class SpectrogramGradCAM:
    """Minimal Grad-CAM for conv nets ending with global pooling + linear.

    A forward hook and a full backward hook are registered on the layer named
    ``target_layer_name``; they capture that layer's output and the gradient
    of the selected class score with respect to that output.

    The hooks stay attached to the model until :meth:`remove_hooks` is called.
    The object can also be used as a context manager, which removes the hooks
    on exit::

        with SpectrogramGradCAM(model, "features.6") as grad_cam:
            cam, logits = grad_cam(x)

    Attributes:
        model: The wrapped model.
        activations: Detached output of the target layer from the most recent
            forward pass (``None`` before the first call).
        gradients: Gradient w.r.t. the target layer output from the most
            recent backward pass (``None`` before the first call).
    """

    def __init__(self, model, target_layer_name: str = "features.6"):
        """Attach Grad-CAM hooks to ``model``.

        Args:
            model: Module whose ``named_modules()`` contains the target layer.
            target_layer_name: Qualified name of the layer to explain
                (e.g. ``"features.6"``).

        Raises:
            KeyError: If ``target_layer_name`` is not a submodule of ``model``.
        """
        self.model = model
        self.gradients = None
        self.activations = None

        # Look the target layer up by its qualified name (e.g. "features.6").
        modules = dict(model.named_modules())
        if target_layer_name not in modules:
            raise KeyError(
                f"No module named {target_layer_name!r} in model; "
                "see model.named_modules() for valid names"
            )
        target_layer = modules[target_layer_name]

        # Keep the handles so the hooks can be detached again later.
        self._fh = target_layer.register_forward_hook(self._forward_hook)
        self._bh = target_layer.register_full_backward_hook(self._backward_hook)

    def remove_hooks(self):
        """Detach both hooks from the target layer (safe to call repeatedly)."""
        for handle in (self._fh, self._bh):
            handle.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove_hooks()

    def _forward_hook(self, module, inp, out):
        # Store a detached view: the values are only read when building the
        # CAM, and holding the graph-attached tensor would keep the whole
        # autograd graph alive between calls.
        self.activations = out.detach()

    def _backward_hook(self, module, grad_input, grad_output):
        # Gradient w.r.t. the layer output.
        self.gradients = grad_output[0].detach()

    def __call__(self, x, class_idx=None):
        """Compute Grad-CAM heatmaps for a batch.

        Args:
            x: Input tensor with ``requires_grad=True`` (``[B,1,H,W]`` per the
                original contract).
            class_idx: Class to explain. ``None`` uses each sample's argmax;
                an ``int`` uses the same class for every sample; a tensor,
                list or array of length ``B`` gives one class per sample
                (a single-element sequence is broadcast to the batch).

        Returns:
            cam: np.ndarray ``[B,h,w]`` in ``[0,1]``, where ``h, w`` is the
                spatial size of the target layer's output (the map is not
                upsampled to the input size).
            logits: np.ndarray ``[B,num_classes]``.

        Raises:
            AssertionError: If ``x`` does not require grad.
            ValueError: If ``class_idx`` has a length other than 1 or ``B``.
            RuntimeError: If the hooks did not fire during this call.

        Side effects:
            Parameter gradients of the model are reset and then populated by
            the backward pass, as is ``x.grad`` when ``x`` is a leaf.
        """
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type is kept for backward compatibility.
        if not x.requires_grad:
            raise AssertionError("Input to Grad-CAM must require grad")

        # Drop results of any previous call so stale tensors can never be
        # combined with this call's forward/backward pass.
        self.activations = None
        self.gradients = None

        # Grad-CAM needs autograd even when the caller is inside no_grad().
        with torch.enable_grad():
            self.model.zero_grad(set_to_none=True)
            logits = self.model(x)  # forward WITH grads
            batch_size = logits.shape[0]

            # Choose class per sample.
            if class_idx is None:
                class_idx = logits.argmax(dim=1)  # [B]
            if isinstance(class_idx, int):
                selected = logits[:, class_idx]  # [B]
            else:
                idx = torch.as_tensor(
                    class_idx, dtype=torch.long, device=logits.device
                ).reshape(-1)
                if idx.numel() == 1:
                    idx = idx.expand(batch_size)
                if idx.numel() != batch_size:
                    raise ValueError(
                        f"class_idx has {idx.numel()} entries but the batch "
                        f"has {batch_size} samples"
                    )
                selected = logits.gather(1, idx.unsqueeze(1)).squeeze(1)  # [B]

            # Backprop from the selected class scores. retain_graph is kept so
            # that a graph upstream of a non-leaf `x` remains usable by the
            # caller afterwards.
            selected.sum().backward(retain_graph=True)

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not fire; were they removed, or is the "
                "target layer unused in the forward pass?"
            )

        # Build CAM from gradients & activations: [B,C,h,w]
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # GAP over h,w
        cam = F.relu((weights * self.activations).sum(dim=1))  # [B,h,w]

        # Normalize per-sample to [0,1]; the epsilon guards constant maps.
        flat = cam.flatten(1)
        cam_min = flat.min(dim=1)[0].view(-1, 1, 1)
        cam_max = flat.max(dim=1)[0].view(-1, 1, 1)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-6)

        return cam.detach().cpu().numpy(), logits.detach().cpu().numpy()