Spaces:
Runtime error
Runtime error
Update modeling_vision.py
Browse files- modeling_vision.py +6 -6
modeling_vision.py
CHANGED
|
@@ -28,14 +28,14 @@ class VisionEncoder(nn.Module):
|
|
| 28 |
return self.proj(image_features)
|
| 29 |
|
| 30 |
def grid_pooling(self, image_features):
|
|
|
|
|
|
|
| 31 |
if self.args.grid_size == -1: # no grid pooling
|
| 32 |
-
return image_features
|
| 33 |
if self.args.grid_size == 0: # take cls token
|
| 34 |
-
return
|
| 35 |
if self.args.grid_size == 1: # global avg pooling
|
| 36 |
-
return image_features.mean(dim=1, keepdim=True)
|
| 37 |
-
cls_features = image_features[:, 0:1, :]
|
| 38 |
-
image_features = image_features[:, 1:, :] #drop cls token
|
| 39 |
B, L, D = image_features.shape
|
| 40 |
H_or_W = int(L**0.5)
|
| 41 |
image_features = image_features.view(B, H_or_W, H_or_W, D)
|
|
@@ -45,4 +45,4 @@ class VisionEncoder(nn.Module):
|
|
| 45 |
kernel_size=grid_stride,
|
| 46 |
stride=grid_stride)
|
| 47 |
image_features = image_features.permute(0, 2, 3, 1).view(B, -1, D)
|
| 48 |
-
return torch.cat((
|
|
|
|
| 28 |
return self.proj(image_features)
|
| 29 |
|
| 30 |
def grid_pooling(self, image_features):
|
| 31 |
+
cls_features = image_features[:, 0:1, :]
|
| 32 |
+
image_features = image_features[:, 1:, :] #drop cls token
|
| 33 |
if self.args.grid_size == -1: # no grid pooling
|
| 34 |
+
return torch.cat((image_features, cls_features), dim=1)
|
| 35 |
if self.args.grid_size == 0: # take cls token
|
| 36 |
+
return cls_features
|
| 37 |
if self.args.grid_size == 1: # global avg pooling
|
| 38 |
+
return torch.cat((image_features.mean(dim=1, keepdim=True), cls_features), dim=1)
|
|
|
|
|
|
|
| 39 |
B, L, D = image_features.shape
|
| 40 |
H_or_W = int(L**0.5)
|
| 41 |
image_features = image_features.view(B, H_or_W, H_or_W, D)
|
|
|
|
| 45 |
kernel_size=grid_stride,
|
| 46 |
stride=grid_stride)
|
| 47 |
image_features = image_features.permute(0, 2, 3, 1).view(B, -1, D)
|
| 48 |
+
return torch.cat((image_features, cls_features), dim=1)
|