import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import os
import requests
import spaces
import torch
import torchvision.transforms as T
import types
import albumentations as A
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

cmap = plt.get_cmap("tab20")
imagenet_transform = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

def get_bg_mask(image):
    # Detect a uniform background from the four image edges: if every border pixel
    # equals the top-left corner color, return a foreground mask (True where a pixel
    # differs from that color); otherwise keep every pixel.
    image = np.array(image)
    if np.all(image[:, 0] == image[0, 0]) and np.all(image[:, -1] == image[0, -1]) \
            and np.all(image[0, :] == image[0, 0]) and np.all(image[-1, :] == image[-1, 0]) \
            and np.all(image[0, 0] == image[0, -1]) and np.all(image[0, 0] == image[-1, 0]) \
            and np.all(image[0, 0] == image[-1, -1]):
        return np.any(image != image[0, 0], -1)
    return np.ones_like(image[:, :, 0], dtype=bool)
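
# Illustrative example (not part of the original demo):
#   img = np.full((4, 4, 3), 255, dtype=np.uint8); img[2, 2] = (255, 0, 0)
#   get_bg_mask(img)  # -> True only at (2, 2); a natural photo with varied borders yields all True
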

def download_image(url, save_path):
    response = requests.get(url)
    with open(save_path, 'wb') as file:
        file.write(response.content)
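
# Note: download_image is defined but never called in this script; the example images
# under examples/ are presumably committed to the Space directly. Hypothetical usage:
#   download_image("https://example.com/cat.png", "examples/cat.png")
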

def process_image(image, res, patch_size, decimation=4):
    # Scale the image so its longer side is `res`, divide that target size by `decimation`
    # to get the patch-grid size, then resize to grid * `patch_size` so the ViT sees an
    # integer number of patches. Finally apply ImageNet normalization.
    image = torch.from_numpy(np.array(image) / 255.).float().permute(2, 0, 1).to(device)
    tgt_size = (int(image.shape[-2] * res / image.shape[-1]), res)
    if image.shape[-2] > image.shape[-1]:
        tgt_size = (res, int(image.shape[-1] * res / image.shape[-2]))
    patch_h, patch_w = tgt_size[0] // decimation, tgt_size[1] // decimation
    image_resized = T.functional.resize(image, (patch_h * patch_size, patch_w * patch_size))
    image_resized = imagenet_transform(image_resized)
    return image_resized
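
# Worked example of the sizing math used by main() below: a 960x1280 (HxW) input with
# res=640, patch_size=14, decimation=8 gives tgt_size=(480, 640), patch_h, patch_w = (60, 80),
# and a network input of 840x1120 pixels, i.e. a 60x80 grid of 14-pixel DINOv2 patches.
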

def generate_grid(x, y, stride):
    x_coords = np.arange(0, x, stride)
    y_coords = np.arange(0, y, stride)
    x_mesh, y_mesh = np.meshgrid(x_coords, y_coords)
    kp = np.column_stack((x_mesh.ravel(), y_mesh.ravel())).astype(float)
    return kp
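
# generate_grid is a keypoint-sampling helper that is not called in this demo; e.g.
# generate_grid(640, 480, 8) returns (x, y) pixel coordinates on an 8-pixel grid
# covering a 640x480 image, as an array of shape (4800, 2).
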

def pca(feat, pca_dim=3):
    mean = torch.mean(feat, dim=0)
    centered_features = feat - mean
    U, S, V = torch.pca_lowrank(centered_features, q=pca_dim)
    reduced_features = torch.matmul(centered_features, V[:, :pca_dim])
    return reduced_features
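
# pca projects D-dimensional patch features onto their top-3 principal directions so they
# can be shown as RGB; e.g. an (N, 768) DINOv2-Base feature matrix becomes (N, 3), which
# main() rescales to [0, 1] for display.
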

def co_pca(feat1, feat2, pca_dim=3):
    co_feats = torch.cat((feat1.reshape(-1, feat1.shape[-1]), feat2.reshape(-1, feat2.shape[-1])), dim=0)
    feats = pca(co_feats, pca_dim)
    feat1_pca = feats[:feat1.shape[0] * feat1.shape[1]].reshape(feat1.shape[0], feat1.shape[1], -1)
    feat2_pca = feats[feat1.shape[0] * feat1.shape[1]:].reshape(feat2.shape[0], feat2.shape[1], -1)
    return feat1_pca, feat2_pca
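
# co_pca fits a single PCA basis on the concatenated features of both images so their 3D
# projections are directly comparable; it is kept as a utility and not used by main(),
# which instead colors image 2 by nearest-neighbor matching against image 1.
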

def draw_correspondence(feat1, feat2, color1, mask1, mask2):
    # For every masked location in image 2, find its nearest-neighbor feature in image 1
    # and copy that location's color.
    original_mask2_shape = mask2.shape
    mask1, mask2 = mask1.reshape(-1), mask2.reshape(-1)
    distances = torch.cdist(feat1.reshape(-1, feat1.shape[-1])[mask1], feat2.reshape(-1, feat2.shape[-1])[mask2])
    nearest = torch.argmin(distances, dim=0)
    color2 = torch.zeros((mask2.shape[0], 3)).to(device)
    color2[mask2] = color1.reshape(-1, 3)[mask1][nearest]
    color2 = color2.reshape(*original_mask2_shape, 3)
    return color2
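
# Shape sketch: with feat1 of shape (H1, W1, D), feat2 of shape (H2, W2, D) and boolean
# masks of shapes (H1, W1) / (H2, W2), the returned color2 has shape (H2, W2, 3): each
# foreground patch of image 2 inherits the color of its closest patch in image 1
# (Euclidean argmin, equivalent to cosine similarity since the features are L2-normalized).
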

def load_model(options):
    original_models = {}
    fine_models = {}
    for option in tqdm(options):
        print('Please wait ...')
        print('loading weights of', option)
        fine_models[option] = torch.hub.load(".", model_card[option], source='local').to(device)
        original_models[option] = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=fine_models[option].backbone_name).eval().to(device)
    print('Done! Now play the demo :)')
    return original_models, fine_models
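
# Note: torch.hub.load(".", ..., source='local') assumes this Space's root contains a
# hubconf.py exposing the entrypoint named in model_card (here "dinov2_base"), which
# returns the finetuned model with a .backbone_name attribute and the .dinov2 /
# .refine_conv modules used in main().
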

if __name__ == "__main__":
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f"device: {device}")
    example_dir = "examples"
    os.makedirs(example_dir, exist_ok=True)

    image_input1 = gr.Image(label="Choose an image:",
                            height=500,
                            type="pil",
                            image_mode='RGB',
                            sources=['upload', 'webcam', 'clipboard'])
    image_input2 = gr.Image(label="Choose another image:",
                            height=500,
                            type="pil",
                            image_mode='RGB',
                            sources=['upload', 'webcam', 'clipboard'])

    options = ['DINOv2-Base']
    model_option = gr.Radio(options, value="DINOv2-Base", label='Choose a 2D foundation model')
    model_card = {
        "DINOv2-Base": "dinov2_base",
    }

    os.environ['TORCH_HOME'] = '/tmp/.cache'
    # os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'

    # Pre-load all models
    original_models, fine_models = load_model(options)

    def main(image1, image2, model_option):
        if image1 is None or image2 is None:
            return None

        # Select model
        original_model = original_models[model_option]
        fine_model = fine_models[model_option]

        images_resized = [process_image(image, 640, 14, decimation=8) for image in [image1, image2]]
        masks = [torch.from_numpy(get_bg_mask(image)).to(device) for image in [image1, image2]]
        feat_shapes = [(images_resized[0].shape[-2] // 14, images_resized[0].shape[-1] // 14),
                       (images_resized[1].shape[-2] // 14, images_resized[1].shape[-1] // 14)]
        masks_resized = [T.functional.resize(mask.float()[None], feat_shape,
                                             interpolation=T.functional.InterpolationMode.NEAREST_EXACT)[0]
                         for mask, feat_shape in zip(masks, feat_shapes)]

        with torch.no_grad():
            original_feats = [original_model.forward_features(image[None])['x_norm_patchtokens'].reshape(*feat_shape, -1)
                              for image, feat_shape in zip(images_resized, feat_shapes)]
            original_feats = [F.normalize(feat, p=2, dim=-1) for feat in original_feats]

            # PCA colors for the foreground of image 1, shared by both model variants
            original_color1 = torch.zeros((original_feats[0].shape[0] * original_feats[0].shape[1], 3)).to(device)
            color = pca(original_feats[0][masks_resized[0] > 0], 3)
            color = (color - color.min()) / (color.max() - color.min())
            original_color1[masks_resized[0].reshape(-1) > 0] = color
            original_color1 = original_color1.reshape(*original_feats[0].shape[:2], 3)
            original_color2 = draw_correspondence(original_feats[0], original_feats[1], original_color1,
                                                  masks_resized[0] > 0, masks_resized[1] > 0)

            fine_feats = [fine_model.dinov2.forward_features(image[None])['x_norm_patchtokens'].reshape(*feat_shape, -1)
                          for image, feat_shape in zip(images_resized, feat_shapes)]
            fine_feats = [fine_model.refine_conv(feat[None].permute(0, 3, 1, 2)).permute(0, 2, 3, 1)[0] for feat in fine_feats]
            fine_feats = [F.normalize(feat, p=2, dim=-1) for feat in fine_feats]
            fine_color2 = draw_correspondence(fine_feats[0], fine_feats[1], original_color1,
                                              masks_resized[0] > 0, masks_resized[1] > 0)

        fig, ax = plt.subplots(2, 2, squeeze=False)
        ax[0][0].imshow(original_color1.cpu().numpy())
        ax[0][1].text(-0.1, 0.5, "Original " + model_option, fontsize=7, rotation=90, va='center', transform=ax[0][1].transAxes)
        ax[0][1].imshow(original_color2.cpu().numpy())
        # ax[1][0].imshow(fine_color1.cpu().numpy())
        ax[1][1].text(-0.1, 0.5, "Finetuned " + model_option, fontsize=7, rotation=90, va='center', transform=ax[1][1].transAxes)
        ax[1][1].imshow(fine_color2.cpu().numpy())
        for xx in ax:
            for x in xx:
                x.xaxis.set_major_formatter(plt.NullFormatter())
                x.yaxis.set_major_formatter(plt.NullFormatter())
                x.set_xticks([])
                x.set_yticks([])
                x.axis('off')
        plt.tight_layout()
        plt.close(fig)
        return fig

    demo = gr.Interface(
        title="<div> \
            <h1>3DCorrEnhance</h1> \
            <h2>Multiview Equivariance Improves 3D Correspondence Understanding with Minimal Feature Finetuning</h2> \
            <h2>ICLR 2025</h2> \
            </div>",
        description="<div style='display: flex; justify-content: center; align-items: center; text-align: center;'> \
            <a href='https://arxiv.org/abs/2411.19458'><img src='https://img.shields.io/badge/arXiv-2411.19458-red'></a> \
            \
            <a href='#'><img src='https://img.shields.io/badge/Project_Page-3DCorrEnhance-green' alt='Project Page (Coming soon)'></a> \
            \
            <a href='https://github.com/qq456cvb/3DCorrEnhance'><img src='https://img.shields.io/badge/Github-Code-blue'></a> \
            </div>",
        fn=main,
        inputs=[image_input1, image_input2, model_option],
        outputs="plot",
        # These example image pairs must already exist under examples/;
        # cache_examples=True pre-computes their outputs at startup.
        examples=[
            ["examples/objs/1-1.png", "examples/objs/1-2.png", "DINOv2-Base"],
            ["examples/scenes/1-1.jpg", "examples/scenes/1-2.jpg", "DINOv2-Base"],
            ["examples/scenes/2-1.jpg", "examples/scenes/2-2.jpg", "DINOv2-Base"],
        ],
        cache_examples=True)
    demo.launch()