import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModel


class DinoFeatureModule(nn.Module):
    """Frozen DINOv2 backbone that exposes shallow/mid/deep hidden states
    as spatial feature maps aligned with an 8x-downsampled input grid."""

    def __init__(self, model_id: str = "facebook/dinov2-giant"):
        super().__init__()
        self.model_id = model_id
        self.dino = AutoModel.from_pretrained(self.model_id, torch_dtype=torch.float32)
        self.dino.eval()
        for param in self.dino.parameters():
            param.requires_grad = False
        frozen = all(not p.requires_grad for p in self.dino.parameters())
        assert frozen, "DINOv2 model parameters are not completely frozen!"
        # dinov2-giant uses a 1536-dim hidden state at every layer.
        self.shallow_dim = 1536
        self.mid_dim = 1536
        self.deep_dim = 1536

    def get_dino_features(self, x):
        with torch.no_grad():
            outputs = self.dino(x, output_hidden_states=True)
        hidden_states = outputs.hidden_states

        _, _, H, W = x.shape
        aspect_ratio = W / H

        # Two taps each from the shallow, middle, and deep thirds of the
        # 40-layer giant backbone (hidden_states[0] is the patch embedding).
        shallow_feat1 = hidden_states[7]
        shallow_feat2 = hidden_states[15]
        mid_feat1 = hidden_states[20]
        mid_feat2 = hidden_states[22]
        deep_feat1 = hidden_states[33]
        deep_feat2 = hidden_states[39]

        def reshape_features(feat):
            # Drop the [CLS] token, keeping only the N patch tokens.
            feat = feat[:, 1:, :]
            B, N, C = feat.shape
            # Recover the patch-grid shape from N and the aspect ratio:
            # N = h * w with w / h ≈ aspect_ratio.
            h = int(math.sqrt(N / aspect_ratio))
            w = int(N / h)
            # Nudge the longer side until the grid exactly tiles N.
            if aspect_ratio > 1:
                if h * w > N:
                    h -= 1
                    w = N // h
                if h * w < N:
                    h += 1
                    w = N // h
            else:
                if h * w > N:
                    w -= 1
                    h = N // w
                if h * w < N:
                    w += 1
                    h = N // w
            assert h * w == N, f"Dimensions mismatch: {h}*{w} != {N}"
            # (B, N, C) -> (B, C, h, w)
            feat = feat.reshape(B, h, w, C).permute(0, 3, 1, 2)
            return feat

        shallow_feat1 = reshape_features(shallow_feat1).float()
        mid_feat1 = reshape_features(mid_feat1).float()
        deep_feat1 = reshape_features(deep_feat1).float()
        shallow_feat2 = reshape_features(shallow_feat2).float()
        mid_feat2 = reshape_features(mid_feat2).float()
        deep_feat2 = reshape_features(deep_feat2).float()

        return shallow_feat1, mid_feat1, deep_feat1, shallow_feat2, mid_feat2, deep_feat2

    def check_image_size(self, x):
        # Reflect-pad height and width up to the next multiple of 16 so the
        # 14-per-8-pixel rescale in forward() lands on whole patches.
        _, _, h, w = x.size()
        pad_size = 16
        mod_pad_h = (pad_size - h % pad_size) % pad_size
        mod_pad_w = (pad_size - w % pad_size) % pad_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def forward(self, inp_img):
        device = inp_img.device
        # Undo ImageNet normalization; the processor below re-normalizes.
        mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
        denormalized_img = inp_img * std + mean
        denormalized_img = self.check_image_size(denormalized_img)
        h_denormalized, w_denormalized = denormalized_img.shape[2], denormalized_img.shape[3]

        # Rescale so the DINOv2 patch grid (stride 14) matches an
        # 8x-downsampled map of the padded input: (h // 8) * 14 pixels yield
        # exactly h // 8 patches per side, guaranteeing spatial alignment.
        target_h = (h_denormalized // 8) * 14
        target_w = (w_denormalized // 8) * 14
        shortest_edge = min(target_h, target_w)

        # The processor is rebuilt per call because `size` depends on the
        # input resolution.
        processor = AutoImageProcessor.from_pretrained(
            self.model_id,
            local_files_only=False,
            do_rescale=False,      # inputs are already in [0, 1]
            do_center_crop=False,  # keep the full (padded) field of view
            use_fast=True,
            size={"shortest_edge": shortest_edge}
        )
        inputs = processor(images=denormalized_img, return_tensors="pt").to(device)

        shallow_feat1, mid_feat1, deep_feat1, shallow_feat2, mid_feat2, deep_feat2 = \
            self.get_dino_features(inputs['pixel_values'])

        dino_features = {
            'shallow_feat1': shallow_feat1,
            'mid_feat1': mid_feat1,
            'deep_feat1': deep_feat1,
            'shallow_feat2': shallow_feat2,
            'mid_feat2': mid_feat2,
            'deep_feat2': deep_feat2
        }
        return dino_features
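

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test of the
# input/output contract. The 256x320 input size and the __main__ guard are
# assumptions added for illustration; any size works because forward()
# reflect-pads to a multiple of 16. Each returned map should have 1536
# channels at 1/8 of the padded input resolution.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    module = DinoFeatureModule()  # downloads facebook/dinov2-giant on first use
    inp_img = torch.randn(1, 3, 256, 320)  # stand-in for an ImageNet-normalized batch
    feats = module(inp_img)
    for name, feat in feats.items():
        print(name, tuple(feat.shape))  # e.g. shallow_feat1 (1, 1536, 32, 40)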