Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from huggingface_hub import hf_hub_download | |
| from transformers.models.vit.modeling_vit import ViTModel | |
| from ...utils import BaseModule | |
class DINOSingleImageTokenizer(BaseModule):
    """Turn images into per-patch token features using a ViT backbone.

    The backbone is a Hugging Face ``ViTModel`` whose architecture is read
    from the ``config.json`` of ``cfg.pretrained_model_name_or_path``.
    Inputs are normalized with fixed per-channel mean/std before being fed
    to the backbone, and the final hidden states are returned channel-first.
    """

    # NOTE(review): ``dataclass`` is imported by this file but the decorator
    # was absent here, so the annotated fields below were plain class
    # attributes rather than dataclass fields. This assumes
    # ``BaseModule.Config`` is itself a dataclass used as a structured
    # config -- confirm against ``utils.BaseModule``.
    @dataclass
    class Config(BaseModule.Config):
        # Hugging Face Hub repo id whose ``config.json`` defines the ViT.
        pretrained_model_name_or_path: str = "facebook/dino-vitb16"
        # When True, gradient checkpointing is enabled on the ViT encoder.
        enable_gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        """Build the ViT backbone and register the normalization buffers.

        Only ``config.json`` is downloaded; the model is instantiated from
        that configuration object, so no pretrained weights are loaded by
        this method (presumably they come from a checkpoint loaded
        elsewhere -- TODO confirm).
        """
        config_path = hf_hub_download(
            repo_id=self.cfg.pretrained_model_name_or_path,
            filename="config.json",
        )
        self.model: ViTModel = ViTModel(
            ViTModel.config_class.from_pretrained(config_path)
        )

        if self.cfg.enable_gradient_checkpointing:
            self.model.encoder.gradient_checkpointing = True

        # Per-channel statistics (these match the standard ImageNet values),
        # shaped (1, 1, 3, 1, 1) so they broadcast over (B, N, C, H, W).
        # Non-persistent: they are constants and stay out of the state dict.
        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Encode images into channel-first token features.

        Args:
            images: Either a 5-D tensor laid out as ``(B, N, C, H, W)`` or a
                4-D tensor ``(B, C, H, W)``, which is treated as a single
                view per sample. ``C`` must broadcast against the 3-channel
                normalization buffers.
            **kwargs: Accepted and ignored.

        Returns:
            The backbone's last hidden state with the token and channel axes
            swapped: ``(B, N, Ct, Nt)`` for 5-D input, or ``(B, Ct, Nt)``
            for 4-D input.
        """
        # A 4-D input gets a temporary view axis that is removed again below.
        packed = images.ndim == 4
        if packed:
            images = images.unsqueeze(1)

        batch_size = images.shape[0]
        images = (images - self.image_mean) / self.image_std

        # Views are folded into the batch axis for the backbone;
        # interpolate_pos_encoding lets the ViT take input resolutions other
        # than the one in its configuration.
        out = self.model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            interpolate_pos_encoding=True,
        )

        # (B*N, Nt, Ct) -> (B*N, Ct, Nt), then restore the view axis.
        local_features = out.last_hidden_state.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)
        return local_features

    def detokenize(self, *args, **kwargs):
        """Not supported by this tokenizer; always raises.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError