import json
import math
import pickle
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Mapping

import cv2
import matplotlib.cm as cm
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import tqdm
from PIL import Image
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.decomposition import PCA
from torch.nn.parameter import Parameter
from torch.utils.data import ConcatDataset, DataLoader, Subset
from torchvision.transforms import functional


class _LoRA_qkv(nn.Module):
    """LoRA adapter wrapped around DINOv2's fused qkv projection.

    In DINOv2 the attention projection is implemented as
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
    Low-rank updates are added to the q and v slices only; k is left untouched.
    """

    def __init__(
        self,
        qkv: nn.Module,
        linear_a_q: nn.Module,
        linear_b_q: nn.Module,
        linear_a_v: nn.Module,
        linear_b_v: nn.Module,
    ):
        super().__init__()
        self.qkv = qkv
        self.linear_a_q = linear_a_q
        self.linear_b_q = linear_b_q
        self.linear_a_v = linear_a_v
        self.linear_b_v = linear_b_v
        self.dim = qkv.in_features
        self.w_identity = torch.eye(qkv.in_features)

    def forward(self, x):
        qkv = self.qkv(x)  # B, N, 3 * org_C
        new_q = self.linear_b_q(self.linear_a_q(x))
        new_v = self.linear_b_v(self.linear_a_v(x))
        qkv[:, :, : self.dim] += new_q    # first third of the channels is q
        qkv[:, :, -self.dim:] += new_v    # last third of the channels is v
        return qkv
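

def _lora_qkv_sanity_check(dim=32, rank=4):
    """Illustrative sketch, not part of the original training code: with the zero-initialised
    B matrices used by FinetuneDINO.reset_parameters below, the LoRA-wrapped projection is
    numerically identical to the frozen qkv, so training starts from the pretrained behaviour.
    The dim/rank values here are arbitrary toy choices."""
    base_qkv = nn.Linear(dim, dim * 3, bias=True)
    lora_qkv = _LoRA_qkv(
        base_qkv,
        nn.Linear(dim, rank, bias=False), nn.Linear(rank, dim, bias=False),
        nn.Linear(dim, rank, bias=False), nn.Linear(rank, dim, bias=False),
    )
    nn.init.zeros_(lora_qkv.linear_b_q.weight)  # B = 0, so the LoRA branch contributes nothing
    nn.init.zeros_(lora_qkv.linear_b_v.weight)
    x = torch.randn(2, 5, dim)
    assert torch.allclose(lora_qkv(x), base_qkv(x))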


def sigmoid(tensor, temp=1.0):
    """Temperature-controlled sigmoid.

    Passes a torch tensor through a sigmoid whose sharpness is controlled by `temp`.
    """
    exponent = -tensor / temp
    # clamp the exponent for numerical stability
    exponent = torch.clamp(exponent, min=-50, max=50)
    y = 1.0 / (1.0 + torch.exp(exponent))
    return y


def interpolate_features(descriptors, pts, h, w, normalize=True, patch_size=14, stride=14):
    # Map pixel coordinates to the [-1, 1] range expected by grid_sample: the first patch
    # center (patch_size / 2) maps to -1 and the last patch center (last_coord_{h,w}) maps to +1.
    last_coord_h = ((h - patch_size) // stride) * stride + (patch_size / 2)
    last_coord_w = ((w - patch_size) // stride) * stride + (patch_size / 2)
    ah = 2 / (last_coord_h - (patch_size / 2))
    aw = 2 / (last_coord_w - (patch_size / 2))
    bh = 1 - last_coord_h * 2 / (last_coord_h - (patch_size / 2))
    bw = 1 - last_coord_w * 2 / (last_coord_w - (patch_size / 2))

    a = torch.tensor([[aw, ah]]).to(pts).float()
    b = torch.tensor([[bw, bh]]).to(pts).float()
    keypoints = a * pts + b

    # Expand dimensions for grid sampling: shape becomes [batch_size, 1, num_keypoints, 2]
    keypoints = keypoints.unsqueeze(-3)

    # Bilinear sampling of the descriptor map at the keypoint locations
    interpolated_features = F.grid_sample(descriptors, keypoints, align_corners=True, padding_mode='border')
    # interpolated_features has shape [batch_size, channels, 1, num_keypoints]
    interpolated_features = interpolated_features.squeeze(-2)

    return F.normalize(interpolated_features, dim=1) if normalize else interpolated_features
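

# Usage sketch for interpolate_features (illustrative shapes only, not from the original code):
# sampling 100 keypoint descriptors from a 1 x 384 x 46 x 46 patch-token map that corresponds
# to a 644 x 644 resized input (46 * 14 = 644):
#
#     feat_map = torch.randn(1, 384, 46, 46)
#     pix_pts = torch.rand(1, 100, 2) * 644        # (x, y) pixel coordinates in the resized image
#     desc = interpolate_features(feat_map, pix_pts, h=644, w=644)
#     # desc: [1, 384, 100], L2-normalised along the channel dimension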


class FinetuneDINO(pl.LightningModule):
    def __init__(self, r, backbone_size, reg=False, datasets=None):
        super().__init__()
        assert r > 0
        self.backbone_size = backbone_size
        self.backbone_archs = {
            "small": "vits14",
            "base": "vitb14",
            "large": "vitl14",
            "giant": "vitg14",
        }
        self.embedding_dims = {
            "small": 384,
            "base": 768,
            "large": 1024,
            "giant": 1536,
        }
        self.backbone_arch = self.backbone_archs[self.backbone_size]
        if reg:
            self.backbone_arch = f"{self.backbone_arch}_reg"
        self.embedding_dim = self.embedding_dims[self.backbone_size]
        self.backbone_name = f"dinov2_{self.backbone_arch}"
        dinov2 = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=self.backbone_name)
        self.datasets = datasets
        self.lora_layer = list(range(len(dinov2.blocks)))  # indices of blocks eligible for LoRA (all by default)

        # Storage for the LoRA projection layers; they are initialised below or loaded from a checkpoint.
        self.w_As = []  # low-rank "A" linear layers
        self.w_Bs = []  # low-rank "B" linear layers

        # Freeze the whole backbone first
        for param in dinov2.parameters():
            param.requires_grad = False

        # Attach LoRA adapters to the last 4 transformer blocks
        for t_layer_i, blk in enumerate(dinov2.blocks[-4:]):
            # Skip blocks that are not selected for LoRA
            if t_layer_i not in self.lora_layer:
                continue
            w_qkv_linear = blk.attn.qkv
            self.dim = w_qkv_linear.in_features
            w_a_linear_q = nn.Linear(self.dim, r, bias=False)
            w_b_linear_q = nn.Linear(r, self.dim, bias=False)
            w_a_linear_v = nn.Linear(self.dim, r, bias=False)
            w_b_linear_v = nn.Linear(r, self.dim, bias=False)
            self.w_As.append(w_a_linear_q)
            self.w_Bs.append(w_b_linear_q)
            self.w_As.append(w_a_linear_v)
            self.w_Bs.append(w_b_linear_v)
            blk.attn.qkv = _LoRA_qkv(
                w_qkv_linear,
                w_a_linear_q,
                w_b_linear_q,
                w_a_linear_v,
                w_b_linear_v,
            )
        self.reset_parameters()
        self.dinov2 = dinov2

        self.downsample_factor = 8
        self.refine_conv = nn.Conv2d(self.embedding_dim, self.embedding_dim, kernel_size=3, stride=1, padding=1)
        self.thresh3d_pos = 5e-3
        self.thres3d_neg = 0.1
        self.patch_size = 14
        self.target_res = 640
        self.input_transform = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

    def reset_parameters(self) -> None:
        # Standard LoRA init: A with Kaiming-uniform, B with zeros, so the adapters start as a no-op.
        for w_A in self.w_As:
            nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
        for w_B in self.w_Bs:
            nn.init.zeros_(w_B.weight)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]):
        # Only the LoRA matrices and the refinement conv are persisted; the frozen backbone is not.
        num_layer = len(self.w_As)  # two entries per adapted block (one for q, one for v)
        a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)}
        b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)}
        checkpoint['state_dict'] = {
            'refine_conv': self.refine_conv.state_dict(),
        }
        checkpoint.update(a_tensors)
        checkpoint.update(b_tensors)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        # Intentionally a no-op: the checkpoint does not contain a full model state dict,
        # so restoring is handled entirely in on_load_checkpoint.
        pass

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.refine_conv.load_state_dict(checkpoint['state_dict']['refine_conv'])
        for i, w_A_linear in enumerate(self.w_As):
            saved_key = f"w_a_{i:03d}"
            saved_tensor = checkpoint[saved_key]
            w_A_linear.weight = Parameter(saved_tensor)
        for i, w_B_linear in enumerate(self.w_Bs):
            saved_key = f"w_b_{i:03d}"
            saved_tensor = checkpoint[saved_key]
            w_B_linear.weight = Parameter(saved_tensor)
        self.loaded = True

    def get_nearest(self, query, database):
        # Nearest-neighbour lookup in descriptor space
        dist = torch.cdist(query, database)
        min_dist, min_idx = torch.min(dist, -1)
        return min_dist, min_idx

    def get_feature(self, rgbs, pts, normalize=True):
        # Resize so the longer side matches target_res, keeping the aspect ratio
        tgt_size = (int(rgbs.shape[-2] * self.target_res / rgbs.shape[-1]), self.target_res)
        if rgbs.shape[-2] > rgbs.shape[-1]:
            tgt_size = (self.target_res, int(rgbs.shape[-1] * self.target_res / rgbs.shape[-2]))
        patch_h, patch_w = tgt_size[0] // self.downsample_factor, tgt_size[1] // self.downsample_factor
        rgb_resized = functional.resize(rgbs, (patch_h * self.patch_size, patch_w * self.patch_size))

        # Rescale keypoint coordinates to the resized image
        resize_factor = [(patch_w * self.patch_size) / rgbs.shape[-1], (patch_h * self.patch_size) / rgbs.shape[-2]]
        pts = pts * torch.tensor(resize_factor).to(pts.device)

        result = self.dinov2.forward_features(self.input_transform(rgb_resized))
        feature = result['x_norm_patchtokens'].reshape(rgb_resized.shape[0], patch_h, patch_w, -1).permute(0, 3, 1, 2)
        feature = self.refine_conv(feature)
        # Sample per-keypoint descriptors from the patch-token feature map
        feature = interpolate_features(feature, pts, h=patch_h * self.patch_size, w=patch_w * self.patch_size, normalize=False).permute(0, 2, 1)
        if normalize:
            feature = F.normalize(feature, p=2, dim=-1)
        return feature

    def get_feature_wo_kp(self, rgbs, normalize=True):
        # Dense variant of get_feature: returns a per-pixel descriptor map instead of per-keypoint features
        tgt_size = (int(rgbs.shape[-2] * self.target_res / rgbs.shape[-1]), self.target_res)
        if rgbs.shape[-2] > rgbs.shape[-1]:
            tgt_size = (self.target_res, int(rgbs.shape[-1] * self.target_res / rgbs.shape[-2]))
        patch_h, patch_w = tgt_size[0] // self.downsample_factor, tgt_size[1] // self.downsample_factor
        rgb_resized = functional.resize(rgbs, (patch_h * self.patch_size, patch_w * self.patch_size))
        result = self.dinov2.forward_features(self.input_transform(rgb_resized))
        feature = result['x_norm_patchtokens'].reshape(rgbs.shape[0], patch_h, patch_w, -1).permute(0, 3, 1, 2)
        feature = self.refine_conv(feature)
        feature = functional.resize(feature, (rgbs.shape[-2], rgbs.shape[-1])).permute(0, 2, 3, 1)
        if normalize:
            feature = F.normalize(feature, p=2, dim=-1)
        return feature

    def training_step(self, batch, batch_idx):
        rgb_1, pts2d_1, pts3d_1 = batch['rgb_1'], batch['pts2d_1'], batch['pts3d_1']
        rgb_2, pts2d_2, pts3d_2 = batch['rgb_2'], batch['pts2d_2'], batch['pts3d_2']
        desc_1 = self.get_feature(rgb_1, pts2d_1, normalize=True)
        desc_2 = self.get_feature(rgb_2, pts2d_2, normalize=True)

        kp3d_dist = torch.cdist(pts3d_1, pts3d_2)          # B x S x T
        sim = torch.bmm(desc_1, desc_2.transpose(-1, -2))  # B x S x T

        # Positive pairs: keypoints whose 3D distance falls below the positive threshold
        pos_idxs = torch.nonzero(kp3d_dist < self.thresh3d_pos, as_tuple=False)
        pos_sim = sim[pos_idxs[:, 0], pos_idxs[:, 1], pos_idxs[:, 2]]
        rpos = sigmoid(pos_sim - 1., temp=0.01) + 1  # si = 1  # pos
        neg_mask = kp3d_dist[pos_idxs[:, 0], pos_idxs[:, 1]] > self.thres3d_neg  # pos x T
        rall = rpos + torch.sum(sigmoid(sim[pos_idxs[:, 0], pos_idxs[:, 1]] - 1., temp=0.01) * neg_mask.float(), -1)  # pos
        ap1 = rpos / rall

        # change the order
        rpos = sigmoid(1. - pos_sim, temp=0.01) + 1  # si = 1  # pos
        neg_mask = kp3d_dist[pos_idxs[:, 0], pos_idxs[:, 1]] > self.thres3d_neg  # pos x T
        rall = rpos + torch.sum(sigmoid(sim[pos_idxs[:, 0], pos_idxs[:, 1]] - pos_sim[:, None].repeat(1, sim.shape[-1]), temp=0.01) * neg_mask.float(), -1)  # pos
        ap2 = rpos / rall

        ap = (ap1 + ap2) / 2
        loss = torch.mean(1. - ap)
        self.log('loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Optimise only the LoRA matrices and the refinement conv; the backbone stays frozen.
        return torch.optim.AdamW(
            [layer.weight for layer in self.w_As]
            + [layer.weight for layer in self.w_Bs]
            + list(self.refine_conv.parameters()),
            lr=1e-5, weight_decay=1e-4)
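

# Training-loop sketch (illustrative; `pair_dataset`, the rank r=4, batch size and epoch count
# are placeholder choices, not part of this module). The dataloader must yield dicts with the
# keys consumed in training_step: 'rgb_1', 'pts2d_1', 'pts3d_1', 'rgb_2', 'pts2d_2', 'pts3d_2'.
#
#     model = FinetuneDINO(r=4, backbone_size="base")
#     train_loader = DataLoader(pair_dataset, batch_size=2, shuffle=True)
#     trainer = pl.Trainer(max_epochs=10, logger=TensorBoardLogger("logs"))
#     trainer.fit(model, train_loader)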