import torch
import CLIP.clip as clip
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from captum.attr import visualization
import os

from CLIP.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

_tokenizer = _Tokenizer()
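# Note: CLIP.clip is expected to be a local, patched copy of CLIP (as in the
# Transformer-MM-Explainability codebase) whose attention blocks cache their
# attention maps as `attn_probs`, used by interpret() below. The stock
# openai/CLIP package does not expose that attribute.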
#@title Control context expansion (number of attention layers to consider)
#@title Number of layers for image Transformer
start_layer = 11  #@param {type:"number"}
#@title Number of layers for text Transformer
start_layer_text = 11  #@param {type:"number"}
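# Blocks with index < start_layer are skipped when accumulating relevance in
# interpret() below, so 11 keeps only the last block of a 12-layer transformer
# (e.g. CLIP ViT-B/32); lower values expand the context over more layers.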
def interpret(image, texts, model, device):
    batch_size = texts.shape[0]
    images = image.repeat(batch_size, 1, 1, 1)
    logits_per_image, logits_per_text = model(images, texts)
    probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()
    index = [i for i in range(batch_size)]
    # Select the diagonal logits (each image copy paired with its own text)
    # and sum them into a scalar we can differentiate.
    one_hot = np.zeros((logits_per_image.shape[0], logits_per_image.shape[1]), dtype=np.float32)
    one_hot[torch.arange(logits_per_image.shape[0]), index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    one_hot = torch.sum(one_hot.to(device) * logits_per_image)
    model.zero_grad()

    # Relevance propagation through the image transformer, in the style of
    # Chefer et al.: R <- R + mean_heads(clamp(grad * attention, 0)) @ R,
    # starting from the identity and skipping blocks below start_layer.
    image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
    num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
    R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
    R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
    for i, blk in enumerate(image_attn_blocks):
        if i < start_layer:
            continue
        grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
        cam = blk.attn_probs.detach()
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
        cam = cam.clamp(min=0).mean(dim=1)
        R = R + torch.bmm(cam, R)
    # Row of the CLS token, minus the CLS column itself: per-patch relevance.
    image_relevance = R[:, 0, 1:]

    # The same propagation rule, applied to the text transformer.
    text_attn_blocks = list(dict(model.transformer.resblocks.named_children()).values())
    num_tokens = text_attn_blocks[0].attn_probs.shape[-1]
    R_text = torch.eye(num_tokens, num_tokens, dtype=text_attn_blocks[0].attn_probs.dtype).to(device)
    R_text = R_text.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
    for i, blk in enumerate(text_attn_blocks):
        if i < start_layer_text:
            continue
        grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
        cam = blk.attn_probs.detach()
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
        cam = cam.clamp(min=0).mean(dim=1)
        R_text = R_text + torch.bmm(cam, R_text)
    text_relevance = R_text

    return text_relevance, image_relevance
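# --- Example usage (a minimal sketch, not part of the original app flow). ---
# Assumes the bundled CLIP fork exposes clip.load with the same signature as
# openai/CLIP; "example.jpg" and the prompts are placeholders.
def _example_interpret():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
    image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
    texts = clip.tokenize(["a photo of a dog", "a photo of a cat"]).to(device)
    # text_relevance: (batch, tokens, tokens); image_relevance: (batch, patches)
    text_relevance, image_relevance = interpret(image, texts, model, device)
    return text_relevance, image_relevance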
def show_image_relevance(image_relevance, image, orig_image, device, show=True):
    # create heatmap from mask on image
    def show_cam_on_image(img, mask):
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        cam = heatmap + np.float32(img)
        cam = cam / np.max(cam)
        return cam

    if show:
        fig, axs = plt.subplots(1, 2)
        axs[0].imshow(orig_image)
        axs[0].axis('off')

    # Upsample the 7x7 patch-relevance grid (ViT-B/32 at 224x224 input) to
    # image resolution, then min-max normalize to [0, 1].
    image_relevance = image_relevance.reshape(1, 1, 7, 7)
    image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
    image_relevance = image_relevance.reshape(224, 224).data.cpu().numpy()
    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
    image = image[0].permute(1, 2, 0).data.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())
    vis = show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    if show:
        axs[1].imshow(vis)
        axs[1].axis('off')
    return image_relevance
def show_heatmap_on_text(text, text_encoding, R_text, show=True):
    # The EOT token has the highest id in CLIP's vocabulary, so argmax over
    # token ids locates it; it plays the role of CLS for the text transformer.
    CLS_idx = text_encoding.argmax(dim=-1)
    # Relevance of each real token (skipping SOT) toward the EOT token.
    R_text = R_text[CLS_idx, 1:CLS_idx]
    text_scores = R_text / R_text.sum()
    text_scores = text_scores.flatten()
    text_tokens = _tokenizer.encode(text)
    text_tokens_decoded = [_tokenizer.decode([a]) for a in text_tokens]
    vis_data_records = [visualization.VisualizationDataRecord(text_scores, 0, 0, 0, 0, 0, text_tokens_decoded, 1)]
    if show:
        visualization.visualize_text(vis_data_records)
    return text_scores, text_tokens_decoded
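# --- Example visualization flow (a sketch mirroring the notebook this file
# appears to derive from; the file name and prompt are placeholders). ---
def _example_visualize(model, preprocess, device, img_path="example.jpg",
                       prompt="a photo of a dog"):
    orig_image = Image.open(img_path)
    image = preprocess(orig_image).unsqueeze(0).to(device)
    texts = clip.tokenize([prompt]).to(device)
    R_text, R_image = interpret(image, texts, model, device)
    show_img_heatmap(R_image[0], image, orig_image, device)
    show_txt_heatmap(prompt, texts[0], R_text[0])
    plt.show()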
def show_img_heatmap(image_relevance, image, orig_image, device, show=True):
    return show_image_relevance(image_relevance, image, orig_image, device, show=show)

def show_txt_heatmap(text, text_encoding, R_text, show=True):
    return show_heatmap_on_text(text, text_encoding, R_text, show=show)
def load_dataset():
    dataset_path = os.path.join('..', '..', 'dummy-data', '71226_segments' + '.pt')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data = torch.load(dataset_path, map_location=device)
    return data
class color:
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    END = '\033[0m'
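# Usage sketch: wrap terminal output in ANSI codes and reset with END, e.g.
#   print(color.BOLD + color.RED + "runtime error" + color.END)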