| """Evaluates cross-modal correspondence of CLIP on PNG images.""" | |
import os
import sys
from os.path import join, exists
import warnings
warnings.filterwarnings('ignore')

from clip_grounding.utils.paths import REPO_PATH
sys.path.append(join(REPO_PATH, "CLIP_explainability/Transformer-MM-Explainability/"))

import torch
import CLIP.clip as clip
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from captum.attr import visualization
from torchmetrics import JaccardIndex
from collections import defaultdict
from IPython.core.display import display, HTML
from skimage import filters

from CLIP_explainability.utils import interpret, show_img_heatmap, show_txt_heatmap, color, _tokenizer

from clip_grounding.datasets.png import PNG
from clip_grounding.utils.image import pad_to_square
from clip_grounding.utils.visualize import show_grid_of_images
from clip_grounding.utils.log import tqdm_iterator, print_update

# global usage
# specify device
device = "cuda" if torch.cuda.is_available() else "cpu"

# load CLIP model
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

def show_cam(mask):
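    """Converts a relevance mask in [0, 1] into a JET-colormapped heatmap rescaled so its maximum is 1."""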
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap
    cam = cam / np.max(cam)
    return cam

def interpret_and_generate(model, img, texts, orig_image, return_outputs=False, show=True):
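    """Runs the Transformer-MM-Explainability relevance pass for an (image, texts) pair and collects per-entry
    text scores, image relevance maps, and decoded tokens."""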
    text = clip.tokenize(texts).to(device)
    R_text, R_image = interpret(model=model, image=img, texts=text, device=device)
    batch_size = text.shape[0]

    outputs = []
    for i in range(batch_size):
        text_scores, text_tokens_decoded = show_txt_heatmap(texts[i], text[i], R_text[i], show=show)
        image_relevance = show_img_heatmap(R_image[i], img, orig_image=orig_image, device=device, show=show)
        plt.show()
        outputs.append({"text_scores": text_scores, "image_relevance": image_relevance, "tokens_decoded": text_tokens_decoded})

    if return_outputs:
        return outputs

def process_entry_text_to_image(entry, unimodal=False):
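    """Prepares a PNG entry for text-to-image grounding: pads the image to a square, preprocesses it for CLIP,
    and picks the annotated text segment (an empty string in the unimodal setting)."""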
    image = entry['image']
    text_mask = entry['text_mask']
    text = entry['text']

    orig_image = pad_to_square(image)
    img = preprocess(orig_image).unsqueeze(0).to(device)

    text_index = text_mask.argmax()
    texts = [text[text_index]] if not unimodal else ['']
    return img, texts, orig_image

def preprocess_ground_truth_mask(mask, resize_shape):
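    """Pads and resizes a binary ground-truth mask to match the relevance map, returning values in [0, 1]."""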
    mask = Image.fromarray(mask.astype(np.uint8) * 255)
    mask = pad_to_square(mask, color=0)
    mask = mask.resize(resize_shape)
    mask = np.asarray(mask) / 255.
    return mask

def apply_otsu_threshold(relevance_map):
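    """Binarizes a continuous relevance map using Otsu's threshold."""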
    threshold = filters.threshold_otsu(relevance_map)
    otsu_map = (relevance_map > threshold).astype(np.uint8)
    return otsu_map

def evaluate_text_to_image(method, dataset, debug=False):
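    """Evaluates text-to-image grounding: Otsu-thresholds each image relevance map and computes IoU against
    the ground-truth segmentation masks. Returns (average_metrics, instance_level_metrics, entry_level_metrics)."""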
    instance_level_metrics = defaultdict(list)
    entry_level_metrics = defaultdict(list)

    jaccard = JaccardIndex(num_classes=2)
    jaccard = jaccard.to(device)

    num_iter = len(dataset)
    if debug:
        num_iter = 100

    iterator = tqdm_iterator(range(num_iter), desc=f"Evaluating on {type(dataset).__name__} dataset")
    for idx in iterator:
        instance = dataset[idx]

        instance_iou = 0.
        for entry in instance:
            # preprocess the image and text
            unimodal = (method == "clip-unimodal")
            test_img, test_texts, orig_image = process_entry_text_to_image(entry, unimodal=unimodal)

            if method in ["clip", "clip-unimodal"]:
                # compute the relevance scores
                outputs = interpret_and_generate(model, test_img, test_texts, orig_image, return_outputs=True, show=False)

                # use the image relevance scores to compute IoU w.r.t. the ground-truth segmentation masks
                # NOTE: since we pass a single entry (batch of size 1), outputs[0] contains the required outputs
                relevance_map = outputs[0]["image_relevance"]
            elif method == "random":
                relevance_map = np.random.uniform(low=0., high=1., size=tuple(test_img.shape[2:]))

            otsu_relevance_map = apply_otsu_threshold(relevance_map)

            ground_truth_mask = entry["image_mask"]
            ground_truth_mask = preprocess_ground_truth_mask(ground_truth_mask, relevance_map.shape)

            entry_iou = jaccard(
                torch.from_numpy(otsu_relevance_map).to(device),
                torch.from_numpy(ground_truth_mask.astype(np.uint8)).to(device),
            )
            entry_iou = entry_iou.item()
            instance_iou += (entry_iou / len(entry))

            entry_level_metrics["iou"].append(entry_iou)

        # capture instance (image-sentence pair) level IoU
        instance_level_metrics["iou"].append(instance_iou)

    average_metrics = {k: np.mean(v) for k, v in entry_level_metrics.items()}
    return (
        average_metrics,
        instance_level_metrics,
        entry_level_metrics
    )

def process_entry_image_to_text(entry, unimodal=False):
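    """Prepares a PNG entry for image-to-text grounding: masks the image with the ground-truth segmentation
    (or uses a blank image in the unimodal setting), pads it to a square, and joins the text segments into one caption."""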
    if not unimodal:
        if len(np.asarray(entry["image"]).shape) == 3:
            mask = np.repeat(np.expand_dims(entry['image_mask'], -1), 3, axis=-1)
        else:
            mask = np.asarray(entry['image_mask'])

        masked_image = (mask * np.asarray(entry['image'])).astype(np.uint8)
        masked_image = Image.fromarray(masked_image)

        orig_image = pad_to_square(masked_image)
        img = preprocess(orig_image).unsqueeze(0).to(device)
    else:
        orig_image_shape = max(np.asarray(entry['image']).shape[:2])
        orig_image = Image.fromarray(np.zeros((orig_image_shape, orig_image_shape, 3), dtype=np.uint8))
        # orig_image = Image.fromarray(np.random.randint(0, 256, (orig_image_shape, orig_image_shape, 3), dtype=np.uint8))
        img = preprocess(orig_image).unsqueeze(0).to(device)

    texts = [' '.join(entry['text'])]
    return img, texts, orig_image

def process_text_mask(text, text_mask, tokens):
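    """Converts the segment-level text mask into a token-level binary mask aligned with the decoded CLIP tokens."""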
    token_level_mask = np.zeros(len(tokens))
    for label, subtext in zip(text_mask, text):
        subtext_tokens = _tokenizer.encode(subtext)
        subtext_tokens_decoded = [_tokenizer.decode([a]) for a in subtext_tokens]
        if label == 1:
            start = tokens.index(subtext_tokens_decoded[0])
            end = tokens.index(subtext_tokens_decoded[-1])
            token_level_mask[start:end + 1] = 1
    return token_level_mask

def evaluate_image_to_text(method, dataset, debug=False, clamp_sentence_len=70):
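    """Evaluates image-to-text grounding: Otsu-thresholds the token relevance scores and computes IoU against
    token-level ground-truth text masks. Returns (average_metrics, instance_level_metrics, entry_level_metrics)."""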
    instance_level_metrics = defaultdict(list)
    entry_level_metrics = defaultdict(list)

    # entries whose text exceeds 77 tokens (CLIP's context limit) are skipped
    num_entries_skipped = 0
    num_total_entries = 0

    num_iter = len(dataset)
    if debug:
        num_iter = 100

    jaccard_image_to_text = JaccardIndex(num_classes=2).to(device)

    iterator = tqdm_iterator(range(num_iter), desc=f"Evaluating on {type(dataset).__name__} dataset")
    for idx in iterator:
        instance = dataset[idx]

        instance_iou = 0.
        for entry in instance:
            num_total_entries += 1

            # preprocess the image and text
            unimodal = (method == "clip-unimodal")
            img, texts, orig_image = process_entry_image_to_text(entry, unimodal=unimodal)

            appx_total_sent_len = np.sum([len(x.split(" ")) for x in texts])
            if appx_total_sent_len > clamp_sentence_len:
                # print(f"Skipping an entry since its text has approximately {appx_total_sent_len} words"
                #       f" while CLIP cannot process beyond {clamp_sentence_len}")
                num_entries_skipped += 1
                continue

            # compute the relevance scores
            if method in ["clip", "clip-unimodal"]:
                try:
                    outputs = interpret_and_generate(model, img, texts, orig_image, return_outputs=True, show=False)
                except Exception:
                    num_entries_skipped += 1
                    continue
            elif method == "random":
                text = texts[0]
                text_tokens = _tokenizer.encode(text)
                text_tokens_decoded = [_tokenizer.decode([a]) for a in text_tokens]
                outputs = [
                    {
                        "text_scores": np.random.uniform(low=0., high=1., size=len(text_tokens_decoded)),
                        "tokens_decoded": text_tokens_decoded,
                    }
                ]

            # use the text relevance scores to compute IoU w.r.t. the ground-truth text masks
            # NOTE: since we pass a single entry (batch of size 1), outputs[0] contains the required outputs
            token_relevance_scores = outputs[0]["text_scores"]
            if isinstance(token_relevance_scores, torch.Tensor):
                token_relevance_scores = token_relevance_scores.cpu().numpy()
            token_relevance_scores = apply_otsu_threshold(token_relevance_scores)

            token_ground_truth_mask = process_text_mask(entry["text"], entry["text_mask"], outputs[0]["tokens_decoded"])

            entry_iou = jaccard_image_to_text(
                torch.from_numpy(token_relevance_scores).to(device),
                torch.from_numpy(token_ground_truth_mask.astype(np.uint8)).to(device),
            )
            entry_iou = entry_iou.item()
            instance_iou += (entry_iou / len(entry))

            entry_level_metrics["iou"].append(entry_iou)

        # capture instance (image-sentence pair) level IoU
        instance_level_metrics["iou"].append(instance_iou)

    print(f"CAUTION: Skipped {(num_entries_skipped / num_total_entries) * 100:.2f}% of entries since their text exceeded the 77-token CLIP limit.")

    average_metrics = {k: np.mean(v) for k, v in entry_level_metrics.items()}
    return (
        average_metrics,
        instance_level_metrics,
        entry_level_metrics
    )

if __name__ == "__main__":
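    # Runs both grounding evaluations (text-to-image and image-to-text) on the PNG val2017 split
    # and caches the resulting metrics under <REPO_PATH>/outputs.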
    import argparse
    parser = argparse.ArgumentParser("Evaluate Image-to-Text & Text-to-Image model")
    parser.add_argument(
        "--eval_method", type=str, default="clip",
        choices=["clip", "random", "clip-unimodal"],
        help="Evaluation method to use",
    )
    parser.add_argument(
        "--ignore_cache", action="store_true",
        help="Ignore cache and force re-generation of the results",
    )
    parser.add_argument(
        "--debug", action="store_true",
        help="Run evaluation on a small subset of the dataset",
    )
    args = parser.parse_args()

    print_update("Using evaluation method: {}".format(args.eval_method))

    clip.clip._MODELS = {
        "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
        "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    }

    # specify device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load CLIP model
    print_update("Loading CLIP model...")
    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
    print()

    # load PNG dataset
    print_update("Loading PNG dataset...")
    dataset = PNG(dataset_root=join(REPO_PATH, "data", "panoptic_narrative_grounding"), split="val2017")
    print()

    # evaluate and save metrics
    metrics_dir = join(REPO_PATH, "outputs")
    os.makedirs(metrics_dir, exist_ok=True)

    metrics_path = join(metrics_dir, f"{args.eval_method}_on_{type(dataset).__name__}_text2image_metrics.pt")
    if (not exists(metrics_path)) or args.ignore_cache:
        print_update("Computing metrics for text-to-image grounding")
        average_metrics, instance_level_metrics, entry_level_metrics = evaluate_text_to_image(
            args.eval_method, dataset, debug=args.debug,
        )
        metrics = {
            "average_metrics": average_metrics,
            "instance_level_metrics": instance_level_metrics,
            "entry_level_metrics": entry_level_metrics,
        }
        torch.save(metrics, metrics_path)
        print("TEXT2IMAGE METRICS SAVED TO:", metrics_path)
    else:
        print(f"Metrics already exist at: {metrics_path}. Loading cached metrics.")
        metrics = torch.load(metrics_path)
        average_metrics = metrics["average_metrics"]
    print("TEXT2IMAGE METRICS:", np.round(average_metrics["iou"], 4))
    print()

    metrics_path = join(metrics_dir, f"{args.eval_method}_on_{type(dataset).__name__}_image2text_metrics.pt")
    if (not exists(metrics_path)) or args.ignore_cache:
        print_update("Computing metrics for image-to-text grounding")
        average_metrics, instance_level_metrics, entry_level_metrics = evaluate_image_to_text(
            args.eval_method, dataset, debug=args.debug,
        )
        torch.save(
            {
                "average_metrics": average_metrics,
                "instance_level_metrics": instance_level_metrics,
                "entry_level_metrics": entry_level_metrics,
            },
            metrics_path,
        )
        print("IMAGE2TEXT METRICS SAVED TO:", metrics_path)
    else:
        print(f"Metrics already exist at: {metrics_path}. Loading cached metrics.")
        metrics = torch.load(metrics_path)
        average_metrics = metrics["average_metrics"]
    print("IMAGE2TEXT METRICS:", np.round(average_metrics["iou"], 4))