import os
import torch
import torchvision.transforms as tvf
import torch.nn.functional as F
import numpy as np

from dust3r.utils.device import to_numpy
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from utils.dust3r_utils import compute_global_alignment

from mast3r.model import AsymmetricMASt3R
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess

from hydra.utils import instantiate
from omegaconf import OmegaConf
class TorchPCA(object):
    def __init__(self, n_components):
        self.n_components = n_components

    def fit(self, X):
        self.mean_ = X.mean(dim=0)
        unbiased = X - self.mean_.unsqueeze(0)
        U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=50)
        self.components_ = V.T
        self.singular_values_ = S
        return self

    def transform(self, X):
        t0 = X - self.mean_.unsqueeze(0)
        projected = t0 @ self.components_.T
        return projected
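# Minimal usage sketch for TorchPCA (illustrative only; the shapes are assumptions):
# feats = torch.randn(1000, 768)                           # 1000 feature vectors, 768 channels
# reduced = TorchPCA(n_components=9).fit(feats).transform(feats)
# assert reduced.shape == (1000, 9)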
def pca(stacked_feat, dim):
    """Fit one PCA over all views jointly and project each [H, W, C] feature map to `dim` channels."""
    flattened_feats = []
    for feat in stacked_feat:
        H, W, C = feat.shape
        feat = feat.reshape(H * W, C).detach()
        flattened_feats.append(feat)
    x = torch.cat(flattened_feats, dim=0)
    fit_pca = TorchPCA(n_components=dim).fit(x)

    projected_feats = []
    for feat in stacked_feat:
        H, W, C = feat.shape
        feat = feat.reshape(H * W, C).detach()
        x_red = fit_pca.transform(feat)
        projected_feats.append(x_red.reshape(H, W, dim))
    projected_feats = torch.stack(projected_feats)
    return projected_feats
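# Hedged example: project two hypothetical 32x32x768 feature maps down to 9
# channels, with the PCA basis fit jointly over both views (values made up):
# stacked = torch.randn(2, 32, 32, 768)
# reduced = pca(stacked, dim=9)                            # -> [2, 32, 32, 9]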
def upsampler(feature, upsampled_height, upsampled_width, max_chunk=None):
    """
    Upsample the feature tensor to the specified height and width.

    Args:
    - feature (torch.Tensor): The input tensor with size [B, H, W, C].
    - upsampled_height (int): The target height after upsampling.
    - upsampled_width (int): The target width after upsampling.
    - max_chunk (int, optional): If set, interpolate at most `max_chunk` maps
      at a time to bound peak memory.

    Returns:
    - upsampled_feature (torch.Tensor): The upsampled tensor with size [B, upsampled_height, upsampled_width, C].
    """
    # Permute the tensor to [B, C, H, W] for interpolation
    feature = feature.permute(0, 3, 1, 2)

    # Perform the upsampling
    if max_chunk:
        upsampled_chunks = []
        for i in range(0, len(feature), max_chunk):
            chunk = feature[i:i+max_chunk]
            upsampled_chunk = F.interpolate(chunk, size=(upsampled_height, upsampled_width), mode='bilinear', align_corners=False)
            upsampled_chunks.append(upsampled_chunk)
        upsampled_feature = torch.cat(upsampled_chunks, dim=0)
    else:
        upsampled_feature = F.interpolate(feature, size=(upsampled_height, upsampled_width), mode='bilinear', align_corners=False)

    # Permute back to [B, H, W, C]
    upsampled_feature = upsampled_feature.permute(0, 2, 3, 1)
    return upsampled_feature
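# Sketch with assumed sizes: upsample 9-channel maps from 32x32 to 512x384,
# two maps at a time to limit peak memory:
# up = upsampler(reduced, 512, 384, max_chunk=2)           # -> [2, 512, 384, 9]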
def visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
    """
    Visualize features and corresponding images, and save the result.

    Args:
        features (torch.Tensor): Feature tensor with shape [B, H, W, C].
        images (list): List of dictionaries containing images with key 'img'. Each image tensor has shape [1, 3, H, W]
                       and values in the range [-1, 1].
        save_dir (str): Directory to save the resulting visualization.
        dim (int): Upper channel index; the 9 channels [dim-9, dim) are shown as three RGB maps.
        feat_type (list): List of feature types.
        file_name (str): Name of the file to save.
    """
    import matplotlib
    matplotlib.use('Agg')
    from matplotlib import pyplot as plt
    import torchvision.utils as vutils

    assert features.dim() == 4, "Input tensor must have 4 dimensions (B, H, W, C)"
    B, H, W, C = features.size()
    features = features[..., dim-9:]

    # Normalize each 3-dimensional feature group to the range [0, 1]
    features_min = features.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values.min(dim=2, keepdim=True).values
    features_max = features.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values
    features = (features - features_min) / (features_max - features_min)

    ##### Save individual feature maps
    # # Create subdirectory for feature visualizations
    # feat_dir = os.path.join(save_dir, 'feature_maps')
    # if feat_type:
    #     feat_dir = os.path.join(feat_dir, '-'.join(feat_type))
    # os.makedirs(feat_dir, exist_ok=True)
    # for i in range(B):
    #     # Extract and save the feature map (channels 3-6)
    #     feat_map = features[i, :, :, 3:6].permute(2, 0, 1)  # [3, H, W]
    #     save_path = os.path.join(feat_dir, f'{i}_feat.png')
    #     vutils.save_image(feat_map, save_path, normalize=False)
    # return feat_dir

    ##### Save feature maps in a single image
    # Set the size of the plot
    fig, axes = plt.subplots(B, 4, figsize=(W*4*0.01, H*B*0.01))
    for i in range(B):
        # Get the original image
        image_tensor = images[i]['img']
        assert image_tensor.dim() == 4 and image_tensor.size(0) == 1 and image_tensor.size(1) == 3, "Image tensor must have shape [1, 3, H, W]"
        image = image_tensor.squeeze(0).permute(1, 2, 0).numpy()  # Convert to (H, W, 3)
        # Scale image values from [-1, 1] to [0, 1]
        image = (image + 1) / 2
        ax = axes[i, 0] if B > 1 else axes[0]
        ax.imshow(image)
        ax.axis('off')

        # Visualize each 3-dimensional feature group
        for j in range(3):
            ax = axes[i, j+1] if B > 1 else axes[j+1]
            if j * 3 < min(C, dim):  # Check whether the feature channels are available
                feature_to_plot = features[i, :, :, j*3:(j+1)*3].cpu().numpy()
                ax.imshow(feature_to_plot)
            else:  # Plot a white image if features are not available
                ax.imshow(torch.ones(H, W, 3).numpy())
            ax.axis('off')

    # Reduce margins and spaces between images
    plt.subplots_adjust(wspace=0.005, hspace=0.005, left=0.01, right=0.99, top=0.99, bottom=0.01)

    # Save the entire plot
    if file_name is None:
        file_name = f'feat_dim{dim-9}-{dim}'
    if feat_type:
        feat_type_str = '-'.join(feat_type)
        file_name = file_name + f'_{feat_type_str}'
    save_path = os.path.join(save_dir, file_name + '.png')
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close()
    return save_path
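# Sketch of a call (the dict layout mirrors the dust3r image loader, i.e. each
# entry is {'img': tensor[1, 3, H, W] in [-1, 1]}; names and paths are assumptions):
# path = visualizer(up, images, save_dir='outputs', dim=9, feat_type=['dino_b16'])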
#### Uncomment this block to visualize feature maps as in Feat2GS's teaser
# import matplotlib.colors as mcolors
# from PIL import Image

# morandi_colors = [
#     '#8AA2A9', '#C98474', '#F2D0A9', '#8D9F87', '#A7A7A7', '#D98E73', '#B24C33', '#5E7460', '#4A6B8A', '#B2CBC2',
#     '#BBC990', '#6B859E', '#B45342', '#4E0000', '#3D0000', '#2C0000', '#1B0000', '#0A0000', '#DCAC99', '#6F936B',
#     '#EBA062', '#FED273', '#9A8EB4', '#706052', '#E9E5E5', '#C4D8D2', '#F2CBBD', '#F6F9F1', '#C5CABC', '#A3968B',
#     '#5C6974', '#BE7B6E', '#C67752', '#C18830', '#8C956C', '#CAC691', '#819992', '#4D797F', '#95AEB2', '#B6C4CF',
#     '#84291C', '#B9551F', '#A96400', '#374B6C', '#C8B493', '#677D5D', '#9882A2', '#2D5F53', '#D2A0AC', '#658D9A',
#     '#9A7265', '#EFE1D2', '#DDD8D1', '#D2C6BC', '#E3C9BC', '#B8AB9F', '#D8BEA4', '#E0D4C5', '#B8B8B6', '#D0CAC3',
#     '#9AA8B5', '#BBC9B9', '#E3E8D8', '#ADB3A4', '#C5C9BB', '#A3968B', '#C2A995', '#EDE1D1', '#EDE8E1', '#EDEBE1',
#     '#CFCFCC', '#AABAC6', '#DCDEE0', '#EAE5E7', '#B7AB9F', '#F7EFE3', '#DED8CF', '#ABCA99', '#C5CD8F', '#959491',
#     '#FFE481', '#C18E99', '#B07C86', '#9F6A73', '#8E5860', '#DEAD44', '#CD9B31', '#BC891E', '#AB770B', '#9A6500',
#     '#778144', '#666F31', '#555D1E', '#444B0B', '#333900', '#67587B', '#564668', '#684563', '#573350', '#684550',
#     '#57333D', '#46212A', '#350F17', '#240004',
# ]

# def rgb_to_hsv(rgb):
#     rgb = rgb.clamp(0, 1)
#     cmax, cmax_idx = rgb.max(dim=-1)
#     cmin = rgb.min(dim=-1).values
#     diff = cmax - cmin
#     h = torch.zeros_like(cmax)
#     h[cmax_idx == 0] = (((rgb[..., 1] - rgb[..., 2]) / diff) % 6)[cmax_idx == 0]
#     h[cmax_idx == 1] = (((rgb[..., 2] - rgb[..., 0]) / diff) + 2)[cmax_idx == 1]
#     h[cmax_idx == 2] = (((rgb[..., 0] - rgb[..., 1]) / diff) + 4)[cmax_idx == 2]
#     h[diff == 0] = 0  # If cmax == cmin
#     h = h / 6
#     s = torch.zeros_like(cmax)
#     s[cmax != 0] = (diff / cmax)[cmax != 0]
#     v = cmax
#     return torch.stack([h, s, v], dim=-1)

# def hsv_to_rgb(hsv):
#     h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
#     c = v * s
#     x = c * (1 - torch.abs((h * 6) % 2 - 1))
#     m = v - c
#     rgb = torch.zeros_like(hsv)
#     mask = (h < 1/6)
#     rgb[mask] = torch.stack([c[mask], x[mask], torch.zeros_like(x[mask])], dim=-1)
#     mask = (1/6 <= h) & (h < 2/6)
#     rgb[mask] = torch.stack([x[mask], c[mask], torch.zeros_like(x[mask])], dim=-1)
#     mask = (2/6 <= h) & (h < 3/6)
#     rgb[mask] = torch.stack([torch.zeros_like(x[mask]), c[mask], x[mask]], dim=-1)
#     mask = (3/6 <= h) & (h < 4/6)
#     rgb[mask] = torch.stack([torch.zeros_like(x[mask]), x[mask], c[mask]], dim=-1)
#     mask = (4/6 <= h) & (h < 5/6)
#     rgb[mask] = torch.stack([x[mask], torch.zeros_like(x[mask]), c[mask]], dim=-1)
#     mask = (5/6 <= h)
#     rgb[mask] = torch.stack([c[mask], torch.zeros_like(x[mask]), x[mask]], dim=-1)
#     return rgb + m.unsqueeze(-1)

# def interpolate_colors(colors, n_colors):
#     # Convert colors to an RGB tensor
#     rgb_colors = torch.tensor([mcolors.to_rgb(color) for color in colors])
#     # Convert RGB to HSV
#     hsv_colors = rgb_to_hsv(rgb_colors)
#     # Sort by hue
#     sorted_indices = torch.argsort(hsv_colors[:, 0])
#     sorted_hsv_colors = hsv_colors[sorted_indices]
#     # Create interpolation indices
#     indices = torch.linspace(0, len(sorted_hsv_colors) - 1, n_colors)
#     # Perform interpolation
#     interpolated_hsv = torch.stack([
#         torch.lerp(sorted_hsv_colors[int(i)],
#                    sorted_hsv_colors[min(int(i) + 1, len(sorted_hsv_colors) - 1)],
#                    i - int(i))
#         for i in indices
#     ])
#     # Convert the interpolated result back to RGB
#     interpolated_rgb = hsv_to_rgb(interpolated_hsv)
#     return interpolated_rgb

# def project_to_morandi(features, morandi_colors):
#     features_flat = features.reshape(-1, 3)
#     distances = torch.cdist(features_flat, morandi_colors)
#     # Get the indices of the closest colors
#     closest_color_indices = torch.argmin(distances, dim=1)
#     # Use the closest Morandi colors directly
#     features_morandi = morandi_colors[closest_color_indices]
#     features_morandi = features_morandi.reshape(features.shape)
#     return features_morandi

# def smooth_color_transform(features, morandi_colors, smoothness=0.1):
#     features_flat = features.reshape(-1, 3)
#     distances = torch.cdist(features_flat, morandi_colors)
#     # Calculate weights
#     weights = torch.exp(-distances / smoothness)
#     weights = weights / weights.sum(dim=1, keepdim=True)
#     # Weighted average
#     features_morandi = torch.matmul(weights, morandi_colors)
#     features_morandi = features_morandi.reshape(features.shape)
#     return features_morandi

# def histogram_matching(source, template):
#     """
#     Match the histogram of the source tensor to that of the template tensor.
#     :param source: Source tensor with shape [B, H, W, 3]
#     :param template: Template tensor with shape [N, 3], where N is the number of colors
#     :return: Source tensor after histogram matching
#     """
#     def match_cumulative_cdf(source, template):
#         src_values, src_indices, src_counts = torch.unique(source, return_inverse=True, return_counts=True)
#         tmpl_values, tmpl_counts = torch.unique(template, return_counts=True)
#         src_quantiles = torch.cumsum(src_counts.float(), 0) / source.numel()
#         tmpl_quantiles = torch.cumsum(tmpl_counts.float(), 0) / template.numel()
#         idx = torch.searchsorted(tmpl_quantiles, src_quantiles)
#         idx = torch.clamp(idx, 1, len(tmpl_quantiles)-1)
#         slope = (tmpl_values[idx] - tmpl_values[idx-1]) / (tmpl_quantiles[idx] - tmpl_quantiles[idx-1])
#         interp_a_values = torch.lerp(tmpl_values[idx-1], tmpl_values[idx],
#                                      (src_quantiles - tmpl_quantiles[idx-1]) * slope)
#         return interp_a_values[src_indices].reshape(source.shape)
#     matched = torch.stack([match_cumulative_cdf(source[..., i], template[:, i]) for i in range(3)], dim=-1)
#     return matched

# def process_features(features):
#     device = features.device
#     n_colors = 1024
#     morandi_colors_tensor = interpolate_colors(morandi_colors, n_colors).to(device)
#     # morandi_colors_tensor = torch.tensor([mcolors.to_rgb(color) for color in morandi_colors]).to(device)
#     # features_morandi = project_to_morandi(features, morandi_colors_tensor)
#     # features_morandi = histogram_matching(features, morandi_colors_tensor)
#     features_morandi = smooth_color_transform(features, morandi_colors_tensor, smoothness=0.05)
#     return features_morandi.cpu().numpy()

# def visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
#     import matplotlib
#     matplotlib.use('Agg')
#     import matplotlib.pyplot as plt
#     import numpy as np
#     import os
#     assert features.dim() == 4, "Input tensor must have 4 dimensions (B, H, W, C)"
#     B, H, W, C = features.size()
#     # Ensure features have at least 3 channels for RGB visualization
#     assert C >= 3, "Features must have at least 3 channels for RGB visualization"
#     features = features[..., :3]
#     # Normalize features to the [0, 1] range
#     features_min = features.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values.min(dim=2, keepdim=True).values
#     features_max = features.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values
#     features = (features - features_min) / (features_max - features_min)
#     features_processed = process_features(features)
#     # Create the directory structure
#     vis_dir = os.path.join(save_dir, 'vis')
#     if feat_type:
#         feat_type_str = '-'.join(feat_type)
#         vis_dir = os.path.join(vis_dir, feat_type_str)
#     os.makedirs(vis_dir, exist_ok=True)
#     # Save individual images for each feature map
#     for i in range(B):
#         if file_name is None:
#             file_name = 'feat_morandi'
#         save_path = os.path.join(vis_dir, f'{file_name}_{i}.png')
#         # Convert to uint8 and save directly
#         img = Image.fromarray((features_processed[i] * 255).astype(np.uint8))
#         img.save(save_path)
#     print(f"Feature maps have been saved in the directory: {vis_dir}")
#     return vis_dir
def mv_visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
    """
    Visualize features and corresponding images, and save the result.
    (For MASt3R decoder or head features; lays out two views per row.)
    """
    import matplotlib
    matplotlib.use('Agg')
    from matplotlib import pyplot as plt
    import os

    B, H, W, _ = features.size()
    features = features[..., dim-9:]

    # Normalize each 3-dimensional feature group to the range [0, 1]
    features_min = features.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values.min(dim=2, keepdim=True).values
    features_max = features.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values
    features = (features - features_min) / (features_max - features_min)

    rows = (B + 1) // 2  # Adjust rows for odd B
    fig, axes = plt.subplots(rows, 8, figsize=(W*8*0.01, H*rows*0.01))
    for i in range(B//2):
        # Left half of the row: image and features
        image = (images[2*i]['img'].squeeze(0).permute(1, 2, 0).numpy() + 1) / 2
        axes[i, 0].imshow(image)
        axes[i, 0].axis('off')
        for j in range(3):
            axes[i, j+1].imshow(features[2*i, :, :, j*3:(j+1)*3].cpu().numpy())
            axes[i, j+1].axis('off')

        # Right half of the row: image and features
        if 2*i + 1 < B:
            image = (images[2*i + 1]['img'].squeeze(0).permute(1, 2, 0).numpy() + 1) / 2
            axes[i, 4].imshow(image)
            axes[i, 4].axis('off')
            for j in range(3):
                axes[i, j+5].imshow(features[2*i + 1, :, :, j*3:(j+1)*3].cpu().numpy())
                axes[i, j+5].axis('off')

    # Handle the last row if B is odd
    if B % 2 != 0:
        image = (images[-1]['img'].squeeze(0).permute(1, 2, 0).numpy() + 1) / 2
        axes[-1, 0].imshow(image)
        axes[-1, 0].axis('off')
        for j in range(3):
            axes[-1, j+1].imshow(features[-1, :, :, j*3:(j+1)*3].cpu().numpy())
            axes[-1, j+1].axis('off')
        # Hide unused columns in the last row
        for j in range(4, 8):
            axes[-1, j].axis('off')

    plt.subplots_adjust(wspace=0.005, hspace=0.005, left=0.01, right=0.99, top=0.99, bottom=0.01)

    # Save the plot
    if file_name is None:
        file_name = f'feat_dim{dim-9}-{dim}'
    if feat_type:
        feat_type_str = '-'.join(feat_type)
        file_name = file_name + f'_{feat_type_str}'
    save_path = os.path.join(save_dir, file_name + '.png')
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close()
    return save_path
def adjust_norm(image: torch.Tensor) -> torch.Tensor:
    inv_normalize = tvf.Normalize(
        mean=[-1, -1, -1],
        std=[1/0.5, 1/0.5, 1/0.5]
    )
    correct_normalize = tvf.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    image = inv_normalize(image)
    image = correct_normalize(image)
    return image

def adjust_midas_norm(image: torch.Tensor) -> torch.Tensor:
    inv_normalize = tvf.Normalize(
        mean=[-1, -1, -1],
        std=[1/0.5, 1/0.5, 1/0.5]
    )
    correct_normalize = tvf.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )
    image = inv_normalize(image)
    image = correct_normalize(image)
    return image

def adjust_clip_norm(image: torch.Tensor) -> torch.Tensor:
    inv_normalize = tvf.Normalize(
        mean=[-1, -1, -1],
        std=[1/0.5, 1/0.5, 1/0.5]
    )
    correct_normalize = tvf.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711]
    )
    image = inv_normalize(image)
    image = correct_normalize(image)
    return image
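# The three adjusters above invert a Normalize(0.5, 0.5) (the dust3r loader maps
# [0, 1] pixels to [-1, 1]) and then re-apply each backbone's expected statistics.
# Worked check for one value, x = -1 (a black pixel):
#   inverse step:  (x - (-1)) / (1/0.5) = (-1 + 1) * 0.5 = 0.0    -> back in [0, 1]
#   ImageNet step (channel 0): (0.0 - 0.485) / 0.229 ≈ -2.118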
class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image2 = torch.clone(image)
        if len(image2.shape) == 4:
            image2 = image2.permute(1, 0, 2, 3)
        for t, m, s in zip(image2, self.mean, self.std):
            t.mul_(s).add_(m)
        return image2.permute(1, 0, 2, 3)

norm = tvf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
midas_norm = tvf.Normalize([0.5] * 3, [0.5] * 3)
midas_unnorm = UnNormalize([0.5] * 3, [0.5] * 3)
def generate_iuv(B, H, W):
    """Per-pixel (view index, u, v) coordinates, each normalized to [0, 1]. Assumes B > 1."""
    i_coords = torch.arange(B).view(B, 1, 1, 1).expand(B, H, W, 1).float() / (B - 1)
    u_coords = torch.linspace(0, 1, W).view(1, 1, W, 1).expand(B, H, W, 1)
    v_coords = torch.linspace(0, 1, H).view(1, H, 1, 1).expand(B, H, W, 1)
    iuv_coords = torch.cat([i_coords, u_coords, v_coords], dim=-1)
    return iuv_coords
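# Illustrative sketch: 4 views at 32x48 give normalized (view index, u, v) coords.
# iuv = generate_iuv(4, 32, 48)                            # -> [4, 32, 48, 3] in [0, 1]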
class FeatureExtractor:
    """
    Extracts and processes per-point (per-pixel) features from images using VFMs.
    Supports multiple VFM features, dimensionality reduction, feature upsampling, and visualization.

    Parameters:
        images (list): List of image info.
        method (str): Pointmap init method, one of ["dust3r", "mast3r"].
        device (str): 'cuda'.
        feat_type (list): VFM(s), chosen from ["dust3r", "mast3r", "dift", "dino_b16", "dinov2_b14", "radio", "clip_b16", "mae_b16", "midas_l16", "sam_base", "iuvrgb"].
        feat_dim (int): Number of PCA dimensions.
        img_base_path (str): Training-view data directory path.
        model_path (str): Model path, './submodules/mast3r/checkpoints/'.
        vis_feat (bool): Visualize and save feature maps.
        vis_key (str): Feature type to visualize (MASt3R only), one of ["decfeat", "desc"].
        focal_avg (bool): Use an averaged focal length across views.
    """
    def __init__(self, images, args, method):
        self.images = images
        self.method = method
        self.device = args.device
        self.feat_type = args.feat_type
        self.feat_dim = args.feat_dim
        self.img_base_path = args.img_base_path
        # self.use_featup = args.use_featup
        self.model_path = args.model_path
        self.vis_feat = args.vis_feat
        self.vis_key = args.vis_key
        self.focal_avg = args.focal_avg
    def get_dust3r_feat(self, **kw):
        model_path = os.path.join(self.model_path, "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth")
        model = AsymmetricCroCo3DStereo.from_pretrained(model_path).to(self.device)
        output = inference(kw['pairs'], model, self.device, batch_size=1)
        scene = global_aligner(output, device=self.device, mode=GlobalAlignerMode.PointCloudOptimizer)
        if self.vis_key:
            assert self.vis_key == 'decfeat', f"Expected vis_key to be 'decfeat', but got {self.vis_key}"
            self.vis_decfeat(kw['pairs'], output=output)
        # del model, output
        # torch.cuda.empty_cache()
        return scene.stacked_feat
    def get_mast3r_feat(self, **kw):
        model_path = os.path.join(self.model_path, "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth")
        model = AsymmetricMASt3R.from_pretrained(model_path).to(self.device)
        cache_dir = os.path.join(self.img_base_path, "cache")
        if os.path.exists(cache_dir):
            os.system(f'rm -rf {cache_dir}')
        # Note: 'depth' in 'refine' evaluates to False, so depth is not optimized here.
        scene = sparse_global_alignment(kw['train_img_list'], kw['pairs'], cache_dir,
                                        model, lr1=0.07, niter1=500, lr2=0.014, niter2=200, device=self.device,
                                        opt_depth='depth' in 'refine', shared_intrinsics=self.focal_avg,
                                        matching_conf_thr=5.)
        if self.vis_key:
            assert self.vis_key in ['decfeat', 'desc'], f"Expected vis_key to be 'decfeat' or 'desc', but got {self.vis_key}"
            self.vis_decfeat(kw['pairs'], model=model)
        # del model
        # torch.cuda.empty_cache()
        return scene.stacked_feat
    def get_feat(self, feat_type):
        """
        Get features using Probe3D.
        """
        cfg = OmegaConf.load(f"configs/backbone/{feat_type}.yaml")
        model = instantiate(cfg.model, output="dense", return_multilayer=False)
        model = model.to(self.device)
        if 'midas' in feat_type:
            image_norm = adjust_midas_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
        # elif 'clip' in self.feat_type:
        #     image_norm = adjust_clip_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
        else:
            image_norm = adjust_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
        with torch.no_grad():
            feats = model(image_norm).permute(0, 2, 3, 1)
        # del model
        # torch.cuda.empty_cache()
        return feats
    # def get_feat(self, feat_type):
    #     """
    #     Get features using FeatUp.
    #     """
    #     original_feat_type = feat_type
    #     use_norm = False if 'maskclip' in feat_type else True
    #     if 'featup' in original_feat_type:
    #         feat_type = feat_type.split('_featup')[0]
    #     # feat_upsampler = torch.hub.load("mhamilton723/FeatUp", feat_type, use_norm=use_norm).to(device)
    #     feat_upsampler = torch.hub.load("/home/chenyue/.cache/torch/hub/mhamilton723_FeatUp_main/", feat_type, use_norm=use_norm, source='local').to(self.device)  ## offline
    #     image_norm = adjust_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
    #     image_norm = F.interpolate(image_norm, size=(224, 224), mode='bilinear', align_corners=False)
    #     if 'featup' in original_feat_type:
    #         feats = feat_upsampler(image_norm).permute(0, 2, 3, 1)
    #     else:
    #         feats = feat_upsampler.model(image_norm).permute(0, 2, 3, 1)
    #     return feats
    def get_iuvrgb(self):
        rgb = torch.cat([i['img'] for i in self.images]).permute(0, 2, 3, 1).to(self.device)
        feats = torch.cat([generate_iuv(*rgb.shape[:-1]).to(self.device), rgb], dim=-1)
        return feats

    def get_iuv(self):
        rgb = torch.cat([i['img'] for i in self.images]).permute(0, 2, 3, 1).to(self.device)
        feats = generate_iuv(*rgb.shape[:-1]).to(self.device)
        return feats
    def preprocess(self, feature, feat_dim, vis_feat=False, is_upsample=True):
        """
        Preprocess features by applying PCA, upsampling, and optionally visualizing.
        """
        if feat_dim:
            feature = pca(feature, feat_dim)
        # else:
        #     feature_min = feature.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values
        #     feature_max = feature.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values
        #     feature = (feature - feature_min) / (feature_max - feature_min + 1e-6)
        #     feature = feature - feature.mean(dim=[0,1,2], keepdim=True)
        torch.cuda.empty_cache()
        if (feature[0].shape[0:-1] != self.images[0]['true_shape'][0]).all() and is_upsample:
            feature = upsampler(feature, *[s for s in self.images[0]['true_shape'][0]])
        print(f"Feature map size >>> height: {feature[0].shape[0]}, width: {feature[0].shape[1]}, channels: {feature[0].shape[2]}")
        if vis_feat:
            save_path = visualizer(feature, self.images, self.img_base_path, feat_type=self.feat_type)
            print(f"The encoder feature visualization has been saved at >>>>> {save_path}")
        return feature
    def vis_decfeat(self, pairs, **kw):
        """
        Visualize decoder or head (MASt3R only) features.
        """
        if 'output' in kw:
            output = kw['output']
        else:
            output = inference(pairs, kw['model'], self.device, batch_size=1, verbose=False)
        decfeat1 = output['pred1'][self.vis_key].detach()
        decfeat2 = output['pred2'][self.vis_key].detach()
        # decfeat1 = pca(decfeat1, 9)
        # decfeat2 = pca(decfeat2, 9)
        decfeat = torch.stack([decfeat1, decfeat2], dim=1).view(-1, *decfeat1.shape[1:])
        decfeat = torch.cat(torch.chunk(decfeat, 2)[::-1], dim=0)
        decfeat = pca(decfeat, 9)
        if (decfeat.shape[1:-1] != self.images[0]['true_shape'][0]).all():
            decfeat = upsampler(decfeat, *[s for s in self.images[0]['true_shape'][0]])
        pair_images = [im for p in pairs[3:] + pairs[:3] for im in p]
        save_path = mv_visualizer(decfeat, pair_images, self.img_base_path,
                                  feat_type=self.feat_type, file_name=f'{self.vis_key}_pcaall_dim0-9')
        print(f"The decoder feature visualization has been saved at >>>>> {save_path}")
    def forward(self, **kw):
        feat_dim = self.feat_dim
        vis_feat = self.vis_feat and len(self.feat_type) == 1
        is_upsample = len(self.feat_type) == 1
        all_feats = {}
        for feat_type in self.feat_type:
            if feat_type == self.method:
                feats = kw['scene'].stacked_feat
            elif feat_type in ['dust3r', 'mast3r']:
                feats = getattr(self, f"get_{feat_type}_feat")(**kw)
            elif feat_type in ['iuv', 'iuvrgb']:
                feats = getattr(self, f"get_{feat_type}")()
                feat_dim = None
            else:
                feats = self.get_feat(feat_type)
            # feats = to_numpy(self.preprocess(feats))
            all_feats[feat_type] = self.preprocess(feats.detach().clone(), feat_dim, vis_feat, is_upsample)

        if len(self.feat_type) > 1:
            all_feats = {k: (v - v.min()) / (v.max() - v.min()) for k, v in all_feats.items()}
            target_size = tuple(s // 16 for s in self.images[0]['true_shape'][0][:2])
            tmp_feats = []
            kickoff = []
            for k, v in all_feats.items():
                if k in ['iuv', 'iuvrgb']:
                    # self.feat_dim -= v.shape[-1]
                    kickoff.append(v)
                else:
                    if v.shape[1:3] != target_size:
                        v = F.interpolate(v.permute(0, 3, 1, 2), size=target_size,
                                          mode='bilinear', align_corners=False).permute(0, 2, 3, 1)
                    tmp_feats.append(v)
            feats = self.preprocess(torch.cat(tmp_feats, dim=-1), self.feat_dim, self.vis_feat and not kickoff)
            if kickoff:
                feats = torch.cat([feats] + kickoff, dim=-1)
                feats = self.preprocess(feats, self.feat_dim, self.vis_feat, is_upsample=False)
        else:
            feats = all_feats[self.feat_type[0]]
        torch.cuda.empty_cache()
        return to_numpy(feats)

    def __call__(self, **kw):
        return self.forward(**kw)
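# Minimal usage sketch for FeatureExtractor. The argument names follow the class
# docstring; the SimpleNamespace stands in for the repo's argparse args, and the
# concrete values, `images`, `scene`, `pairs`, and `train_img_list` are assumptions:
# from types import SimpleNamespace
# args = SimpleNamespace(device='cuda', feat_type=['dino_b16'], feat_dim=9,
#                        img_base_path='data/scene', model_path='./submodules/mast3r/checkpoints/',
#                        vis_feat=True, vis_key=None, focal_avg=False)
# extractor = FeatureExtractor(images, args, method='dust3r')
# per_pixel_feats = extractor(scene=scene, pairs=pairs, train_img_list=train_img_list)  # numpy [B, H, W, feat_dim]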
class InitMethod:
    """
    Initialize the pointmap and camera parameters via DUSt3R or MASt3R.
    """
    def __init__(self, args):
        self.method = args.method
        self.n_views = args.n_views
        self.device = args.device
        self.img_base_path = args.img_base_path
        self.focal_avg = args.focal_avg
        self.tsdf_thresh = args.tsdf_thresh
        self.min_conf_thr = args.min_conf_thr
        if self.method == 'dust3r':
            self.model_path = os.path.join(args.model_path, "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth")
        else:
            self.model_path = os.path.join(args.model_path, "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth")
    def get_dust3r(self):
        return AsymmetricCroCo3DStereo.from_pretrained(self.model_path).to(self.device)

    def get_mast3r(self):
        return AsymmetricMASt3R.from_pretrained(self.model_path).to(self.device)

    def infer_dust3r(self, **kw):
        output = inference(kw['pairs'], kw['model'], self.device, batch_size=1)
        scene = global_aligner(output, device=self.device, mode=GlobalAlignerMode.PointCloudOptimizer)
        loss = compute_global_alignment(scene=scene, init="mst", niter=300, schedule='linear', lr=0.01,
                                        focal_avg=self.focal_avg, known_focal=kw.get('known_focal', None))
        scene = scene.clean_pointcloud()
        return scene

    def infer_mast3r(self, **kw):
        cache_dir = os.path.join(self.img_base_path, "cache")
        if os.path.exists(cache_dir):
            os.system(f'rm -rf {cache_dir}')
        # Note: 'depth' in 'refine' evaluates to False, so depth is not optimized here.
        scene = sparse_global_alignment(kw['train_img_list'], kw['pairs'], cache_dir,
                                        kw['model'], lr1=0.07, niter1=500, lr2=0.014, niter2=200, device=self.device,
                                        opt_depth='depth' in 'refine', shared_intrinsics=self.focal_avg,
                                        matching_conf_thr=5.)
        return scene
    def get_dust3r_info(self, scene):
        imgs = to_numpy(scene.imgs)
        focals = scene.get_focals()
        poses = to_numpy(scene.get_im_poses())
        pts3d = to_numpy(scene.get_pts3d())
        # pts3d = to_numpy(scene.get_planes3d())
        scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
        confidence_masks = to_numpy(scene.get_masks())
        intrinsics = to_numpy(scene.get_intrinsics())
        return imgs, focals, poses, intrinsics, pts3d, confidence_masks

    def get_mast3r_info(self, scene):
        imgs = to_numpy(scene.imgs)
        focals = scene.get_focals()[:, None]
        poses = to_numpy(scene.get_im_poses())
        intrinsics = to_numpy(scene.intrinsics)
        tsdf = TSDFPostProcess(scene, TSDF_thresh=self.tsdf_thresh)
        pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=True))
        pts3d = [arr.reshape((*imgs[0].shape[:2], 3)) for arr in pts3d]
        confidence_masks = np.array(to_numpy([c > self.min_conf_thr for c in confs]))
        return imgs, focals, poses, intrinsics, pts3d, confidence_masks

    def get_dust3r_depth(self, scene):
        return to_numpy(scene.get_depthmaps())

    def get_mast3r_depth(self, scene):
        imgs = to_numpy(scene.imgs)
        tsdf = TSDFPostProcess(scene, TSDF_thresh=self.tsdf_thresh)
        _, depthmaps, _ = to_numpy(tsdf.get_dense_pts3d(clean_depth=True))
        depthmaps = [arr.reshape((*imgs[0].shape[:2], 3)) for arr in depthmaps]
        return depthmaps

    def get_model(self):
        return getattr(self, f"get_{self.method}")()

    def infer(self, **kw):
        return getattr(self, f"infer_{self.method}")(**kw)

    def get_info(self, scene):
        return getattr(self, f"get_{self.method}_info")(scene)

    def get_depth(self, scene):
        return getattr(self, f"get_{self.method}_depth")(scene)
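# End-to-end sketch of the init flow (hedged; `args`, `pairs`, and
# `train_img_list` are assumed to come from the surrounding pipeline):
# init = InitMethod(args)
# model = init.get_model()
# scene = init.infer(pairs=pairs, model=model, train_img_list=train_img_list)
# imgs, focals, poses, intrinsics, pts3d, conf_masks = init.get_info(scene)
# depths = init.get_depth(scene)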