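"""Frozen DINOv2 (dinov2-giant) feature extractor for multi-level patch features."""
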
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModel


class DinoFeatureModule(nn.Module):
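    """Frozen DINOv2 backbone exposing multi-level patch features.

    The backbone is loaded once, set to eval mode, and fully frozen; six
    intermediate hidden states (two shallow, two mid, two deep) are
    reshaped into B x C x H x W feature maps for downstream use.
    """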
    def __init__(self, model_id: str = "facebook/dinov2-giant"):
        super().__init__()
        dtype = torch.float32
        self.model_id = model_id
        self.dino = AutoModel.from_pretrained(
            self.model_id,
            torch_dtype=dtype
        )

        # Freeze the backbone: inference only, no gradient updates.
        self.dino.eval()
        for param in self.dino.parameters():
            param.requires_grad = False

        frozen = all(not p.requires_grad for p in self.dino.parameters())
        assert frozen, "DINOv2 model parameters are not completely frozen!"

        # dinov2-giant uses a hidden size of 1536 at every layer.
        self.shallow_dim = 1536
        self.mid_dim = 1536
        self.deep_dim = 1536
        
    def get_dino_features(self, x):
        with torch.no_grad():
            outputs = self.dino(x, output_hidden_states=True)
            hidden_states = outputs.hidden_states

            _, _, H, W = x.shape
            aspect_ratio = W / H

            # Sample two hidden states from each depth range of the
            # 40-layer giant backbone.
            shallow_feat1 = hidden_states[7]
            shallow_feat2 = hidden_states[15]
            mid_feat1 = hidden_states[20]
            mid_feat2 = hidden_states[22]
            deep_feat1 = hidden_states[33]
            deep_feat2 = hidden_states[39]

            def reshape_features(feat):
                # Drop the CLS token, keeping only the N patch tokens.
                feat = feat[:, 1:, :]
                B, N, C = feat.shape

                # Recover the patch-grid shape from the token count and the
                # input aspect ratio; integer truncation can leave h * w != N,
                # so nudge the longer side until the grid is exact.
                h = int(math.sqrt(N / aspect_ratio))
                w = int(N / h)

                if aspect_ratio > 1:
                    if h * w > N:
                        h -= 1
                        w = N // h
                    if h * w < N:
                        h += 1
                        w = N // h
                else:
                    if h * w > N:
                        w -= 1
                        h = N // w
                    if h * w < N:
                        w += 1
                        h = N // w

                assert h * w == N, f"Dimensions mismatch: {h}*{w} != {N}"

                # (B, N, C) -> (B, C, h, w) feature map.
                feat = feat.reshape(B, h, w, C).permute(0, 3, 1, 2)
                return feat

            shallow_feat1 = reshape_features(shallow_feat1).float()
            mid_feat1 = reshape_features(mid_feat1).float()
            deep_feat1 = reshape_features(deep_feat1).float()
            shallow_feat2 = reshape_features(shallow_feat2).float()
            mid_feat2 = reshape_features(mid_feat2).float()
            deep_feat2 = reshape_features(deep_feat2).float()

            return shallow_feat1, mid_feat1, deep_feat1, shallow_feat2, mid_feat2, deep_feat2
    
    def check_image_size(self, x):
        # Reflect-pad the input so both spatial dims are multiples of 16.
        _, _, h, w = x.size()
        pad_size = 16
        mod_pad_h = (pad_size - h % pad_size) % pad_size
        mod_pad_w = (pad_size - w % pad_size) % pad_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def forward(self, inp_img):
        device = inp_img.device

        # The input is assumed to be normalized with ImageNet statistics;
        # undo that so the DINOv2 processor can apply its own normalization.
        mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

        denormalized_img = inp_img * std + mean
        denormalized_img = self.check_image_size(denormalized_img)
        h_denormalized, w_denormalized = denormalized_img.shape[2], denormalized_img.shape[3]

        # Rescale so the 14-pixel DINOv2 patch grid stays spatially aligned
        # with the 8-pixel stride of the padded input: every 8 input pixels
        # map onto one 14-pixel patch.
        target_h = (h_denormalized // 8) * 14
        target_w = (w_denormalized // 8) * 14

        # The processor is rebuilt per call because the resize target depends
        # on the input resolution.
        shortest_edge = min(target_h, target_w)
        processor = AutoImageProcessor.from_pretrained(
            self.model_id,
            do_rescale=False,
            do_center_crop=False,
            use_fast=True,
            size={"shortest_edge": shortest_edge}
        )

        inputs = processor(
            images=denormalized_img,
            return_tensors="pt"
        ).to(device)

        # Run the frozen backbone and collect the six feature maps.
        shallow_feat1, mid_feat1, deep_feat1, shallow_feat2, mid_feat2, deep_feat2 = self.get_dino_features(inputs['pixel_values'])
        
        dino_features = {
            'shallow_feat1': shallow_feat1,
            'mid_feat1': mid_feat1,
            'deep_feat1': deep_feat1,
            'shallow_feat2': shallow_feat2,
            'mid_feat2': mid_feat2,
            'deep_feat2': deep_feat2
        }
        
        return dino_features
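

if __name__ == "__main__":
    # Minimal usage sketch (an addition, not part of the original module):
    # assumes network access to download the dinov2-giant checkpoint, and a
    # dummy batch normalized with ImageNet statistics, as forward() expects.
    module = DinoFeatureModule()

    # Dummy ImageNet-normalized input: one 3 x 256 x 256 image.
    img = torch.randn(1, 3, 256, 256)

    feats = module(img)
    for name, feat in feats.items():
        # Each entry is a B x 1536 x h x w patch-feature map.
        print(name, tuple(feat.shape))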