Spaces:
Sleeping
Sleeping
| from transformers import AutoFeatureExtractor, AutoModel | |
| import torch | |
| from torchvision.transforms.functional import to_pil_image | |
| from einops import rearrange, reduce | |
| from skops import hub_utils | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import gradio as gr | |
| import os | |
| import glob | |
| import pickle | |
# Display names of the supported visual Emb-GAM setups. The two lists below are
# index-aligned with this one: entry i of each belongs to setup i.
setups = ['ResNet-50', 'ViT', 'DINO-ResNet-50', 'DINO-ViT']

# Checkpoint identifiers of the embedding models, one per setup.
embedder_names = [
    'microsoft/resnet-50',
    'google/vit-base-patch16-224',
    'Ramos-Ramos/dino-resnet-50',
    'facebook/dino-vitb16',
]

# Repository names of the fitted GAMs, one per setup.
gam_names = [
    'emb-gam-resnet',
    'emb-gam-vit',
    'emb-gam-dino-resnet',
    'emb-gam-dino',
]

# Lookups from a checkpoint / GAM name to the setup name it belongs to.
embedder_to_setup = {
    embedder_name: setup
    for embedder_name, setup in zip(embedder_names, setups)
}
gam_to_setup = {
    gam_name: setup
    for gam_name, setup in zip(gam_names, setups)
}

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def _resnet_patch_embeddings(outputs):
    '''Flattens the spatial grid of the last hidden state into a sequence:
    (b, d, h, w) -> (b, h*w, d)'''
    return rearrange(outputs.last_hidden_state, 'b d h w -> b (h w) d')


def _vit_patch_embeddings(outputs):
    '''Returns the last hidden state without its first token (presumably the
    [CLS] token -- confirm against the model docs)'''
    return outputs.last_hidden_state[:, 1:]


def _build_embedder(name):
    '''Loads the feature extractor and model for checkpoint `name` and bundles
    them with the patch-grid size and the output post-processing function'''
    feature_extractor = AutoFeatureExtractor.from_pretrained(name)
    model = AutoModel.from_pretrained(name).eval().to(device)

    if 'resnet-50' in name:
        # hard-coded grid size; presumably the 7x7 final feature map of a
        # ResNet-50 on 224x224 inputs -- TODO confirm
        num_patches_side = 7
        embedding_postprocess = _resnet_patch_embeddings
    else:
        # patch grid side derived from the model config
        num_patches_side = model.config.image_size // model.config.patch_size
        embedding_postprocess = _vit_patch_embeddings

    return {
        'feature_extractor': feature_extractor,
        'model': model,
        'num_patches_side': num_patches_side,
        'embedding_postprocess': embedding_postprocess,
    }


# Embedding models keyed by setup name, loaded once at import time.
embedders = {
    embedder_to_setup[name]: _build_embedder(name)
    for name in embedder_names
}
# Fitted GAMs keyed by setup name, downloaded on first use and then loaded
# from the local copy.
gams = {}
for name in gam_names:
    model_path = os.path.join(name, 'model.pkl')
    # Key the download on the pickle itself rather than on the directory: a
    # directory left behind by an interrupted first run would otherwise be
    # mistaken for a finished download and every later start would fail on open().
    if not os.path.exists(model_path):
        os.makedirs(name, exist_ok=True)
        hub_utils.download(repo_id=f'Ramos-Ramos/{name}', dst=name)
    # SECURITY: pickle.load can execute arbitrary code contained in the file.
    # Only load models from repositories you trust.
    with open(model_path, 'rb') as infile:
        gams[gam_to_setup[name]] = pickle.load(infile)
# Class names used as heatmap column titles; entry j titles the j-th
# contribution map of each setup.
labels = [
    'tench', 'English springer', 'cassette player', 'chain saw', 'church',
    'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute',
]
def visualize(input_img, visual_emb_gam_setups, show_scores, show_cbars):
    '''Visualizes the patch contributions to all labels of one or more visual
    Emb-GAMs

    Args:
        input_img: image to explain; it is passed to each setup's feature
            extractor and drawn with `imshow` (the Gradio input is declared
            with type='pil').
        visual_emb_gam_setups: names of the setups to run; each must be a key
            of both `embedders` and `gams`.
        show_scores: if truthy, label every heatmap with the sum of its patch
            contributions.
        show_cbars: if truthy, draw a color bar next to every heatmap.

    Returns:
        A single matplotlib figure: the input image in the first column and,
        per selected setup, one row of heatmaps (one per entry of `labels`).
        An empty figure is returned when no setup is selected.
    '''

    if not visual_emb_gam_setups:
        # the interface has exactly one Plot output, so return one figure
        return plt.Figure()

    patch_contributions = {}

    # get patch contributions per Emb-GAM
    for setup in visual_emb_gam_setups:

        # prepare embedding model
        embedder_setup = embedders[setup]
        feature_extractor = embedder_setup['feature_extractor']
        embedding_postprocess = embedder_setup['embedding_postprocess']
        num_patches_side = embedder_setup['num_patches_side']

        # prepare GAM
        gam = gams[setup]

        # get patch embeddings
        inputs = {
            k: v.to(device)
            for k, v
            in feature_extractor(input_img, return_tensors='pt').items()
        }
        with torch.no_grad():
            patch_embeddings = embedding_postprocess(
                embedder_setup['model'](**inputs)
            ).cpu()[0]

        # get patch contributions: project every patch embedding through the
        # GAM coefficients and spread the intercept evenly over the patches,
        # so the contributions of one label sum to that label's score
        # (assumes gam.coef_ is (num_labels, embedding_dim) -- TODO confirm)
        patch_contributions[setup] = (
            gam.coef_
            @ patch_embeddings.T.numpy()
            + gam.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
        ).reshape(-1, num_patches_side, num_patches_side)

    # plot heatmaps

    num_rows = len(visual_emb_gam_setups)
    num_labels = len(labels)

    # set up figure; squeeze=False keeps `axs` two-dimensional even for a
    # single setup, so no special-casing of the one-row layout is needed
    fig, axs = plt.subplots(
        num_rows,
        num_labels + 1,
        figsize=(20, round(10/4 * num_rows)),
        squeeze=False
    )

    # replace the first column by one tall axis spanning all rows
    gs = axs[0, 0].get_gridspec()
    for ax in axs[:, 0]:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0])

    # plot original image
    ax_orig_img.imshow(input_img)
    ax_orig_img.axis('off')

    # plot patch contributions
    for i, setup in enumerate(visual_emb_gam_setups):
        contributions = patch_contributions[setup]
        # share one color scale across all labels of a setup
        vmin = contributions.min()
        vmax = contributions.max()
        for j, label in enumerate(labels):
            ax = axs[i, j + 1]
            sns.heatmap(
                contributions[j],
                ax=ax,
                square=True,
                vmin=vmin,
                vmax=vmax,
                cbar=show_cbars
            )
            if show_scores:
                ax.set_xlabel(f'{contributions[j].sum():.2f}')
            if j == 0:
                ax.set_ylabel(setup)
            if i == 0:
                ax.set_title(label)
            ax.set_xticks([])
            ax.set_yticks([])

    fig.tight_layout()

    return fig
# Markdown shown above the interface.
description = 'Visualize the patch contributions of [visual Emb-GAMs](https://huggingface.co/models?other=visual%20emb-gam) to class labels.'

# Markdown shown below the interface.
article = '''An extension of [Emb-GAMs](https://arxiv.org/abs/2209.11799), visual Emb-GAMs classify images by embedding images, taking intermediate representations correponding to different spatial regions, summing these up and predicting a class label from the sum using a GAM.
The use of a sum of embeddings allows us to visualize which regions of an image contributed positive or negatively to each class score.
No paper yet, but you can refer to this [tweet](https://twitter.com/patrick_j_ramos/status/1586992857969147904?s=20&t=5-j5gKK0FpZOgzR_9Wdm1g). Also, check out the original [Emb-GAM paper](https://arxiv.org/abs/2209.11799).'''

# Input widgets, in the order of `visualize`'s parameters.
_input_components = [
    gr.Image(shape=(224, 224), type='pil', label='Input image'),
    gr.CheckboxGroup(setups, value=setups, label='Visual Emb-GAM'),
    gr.Checkbox(label='Show scores'),
    gr.Checkbox(label='Show color bars')
]

# The single output: the figure returned by `visualize`.
_output_components = [
    gr.Plot(label='Patch contributions'),
]

# One example row per file under examples/: all setups selected, scores and
# color bars switched off.
_example_rows = [
    [path, setups, False, False]
    for path in glob.glob('examples/*')
]

demo = gr.Interface(
    title='Visual Emb-GAM Probing',
    description=description,
    article=article,
    fn=visualize,
    inputs=_input_components,
    outputs=_output_components,
    examples=_example_rows
)

demo.launch(debug=True)