import random
import torch
from torch import nn
import numpy as np
import re
from einops import rearrange
from dataclasses import dataclass
from torchvision import transforms

from transformers import CLIPTokenizer, CLIPImageProcessor
from transformers import AutoImageProcessor
from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer
from transformers.utils import ModelOutput
from typing import Iterable, Optional, Union, List

import craftsman
from craftsman.utils.typing import *

from .clip.modeling_clip import CLIPModel
from .clip.modeling_conditional_clip import ConditionalCLIPModel
from .base import BaseEmbedder, ImageType
from .dino_v2.modeling_dinov2 import Dinov2Model
from .dino_v2.modeling_conditional_dinov2 import ConditionalDinov2Model

# NOTE: the @dataclass decorators below are assumed; transformers.ModelOutput
# subclasses must be declared as dataclasses for keyword construction and
# attribute access to work.
@dataclass
class CLIPEmbedOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    embeds: torch.FloatTensor = None


@dataclass
class DINOEmbedOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None

class CondEmbedder(BaseEmbedder):

    @dataclass  # assumed, following the BaseEmbedder.Config dataclass pattern
    class Config(BaseEmbedder.Config):
        pretrained_model_name_or_path: Optional[str] = None  # pretrained checkpoint for the full condition model
        pretrained_clip_name_or_path: Optional[str] = None   # pretrained model name or path for CLIP
        pretrained_dino_name_or_path: Optional[str] = None   # pretrained model name or path for DINOv2
        pretrained_linear_proj: Optional[str] = None
        freeze_modulation_clip: bool = False
        freeze_modulation_dino: bool = False
        config_path: str = ''
        enable_gradient_checkpointing: bool = False
        embeds_fusion_mode: int = 1  # 0: sum | 1: concat
        linear_proj_init: str = "constant"
        text_max_length: int = 77
        image_size_clip: int = 224
        image_size_dino: int = 224

    cfg: Config
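
    # NOTE (added): configure() below also reads fields that are assumed to be
    # inherited from BaseEmbedder.Config rather than defined here, e.g.
    # encode_camera, camera_embeds_dim, n_views, zero_uncond_embeds, and the
    # self.cameras tensor and encode_camera() helper provided by the base class.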

    def configure(self) -> None:
        super().configure()

        # Load the CLIP model and processor
        if not self.cfg.encode_camera:
            if self.cfg.pretrained_clip_name_or_path is not None:
                self.clip_model: CLIPModel = CLIPModel.from_pretrained(self.cfg.pretrained_clip_name_or_path)
            else:
                self.clip_model: CLIPModel = CLIPModel(config=ConditionalCLIPModel.config_class.from_pretrained(
                    "openai/clip-vit-large-patch14",
                ))
            if self.cfg.pretrained_dino_name_or_path is not None:
                self.dino_model: Dinov2Model = Dinov2Model.from_pretrained(self.cfg.pretrained_dino_name_or_path)
            else:
                self.dino_model: Dinov2Model = Dinov2Model(config=ConditionalDinov2Model.config_class.from_pretrained(
                    "facebook/dinov2-base",
                ))
        else:
            if self.cfg.pretrained_clip_name_or_path == '':
                assert self.cfg.config_path is not None, "The config path should be provided"
                conditional_clip_config = ConditionalCLIPModel.config_class.from_json_file(self.cfg.config_path)
                conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim
                self.clip_model: CLIPModel = ConditionalCLIPModel(conditional_clip_config)
            else:
                # clip
                conditional_clip_config = ConditionalCLIPModel.config_class.from_pretrained(
                    self.cfg.pretrained_clip_name_or_path,
                )
                conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim
                self.clip_model: CLIPModel = ConditionalCLIPModel.from_pretrained(
                    self.cfg.pretrained_clip_name_or_path,
                    vision_config=conditional_clip_config.vision_config
                )

                # dino
                conditional_vit_config = ConditionalDinov2Model.config_class.from_pretrained(
                    self.cfg.pretrained_dino_name_or_path,
                )
                conditional_vit_config.modulation_dim = self.cfg.camera_embeds_dim
                self.dino_model: ConditionalDinov2Model = ConditionalDinov2Model.from_pretrained(
                    self.cfg.pretrained_dino_name_or_path,
                    config=conditional_vit_config
                )

        self.image_preprocess_clip = CLIPImageProcessor()
        self.image_preprocess_dino = AutoImageProcessor.from_pretrained(
            self.cfg.pretrained_dino_name_or_path if self.cfg.pretrained_dino_name_or_path is not None else "facebook/dinov2-base",
        )

        self.transform_clip = transforms.Compose(
            [
                transforms.Resize(self.cfg.image_size_clip, transforms.InterpolationMode.BICUBIC, antialias=True),
                transforms.CenterCrop(self.cfg.image_size_clip),  # crop a (224, 224) square
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )
        self.transform_dino = transforms.Compose(
            [
                transforms.Resize(self.cfg.image_size_dino, transforms.InterpolationMode.BICUBIC, antialias=True),
                transforms.CenterCrop(self.cfg.image_size_dino),  # crop a (224, 224) square
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )
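
        # Note (added): the two Normalize() calls above use the standard CLIP and
        # ImageNet mean/std statistics, matching the respective pretrained backbones.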

        if self.cfg.enable_gradient_checkpointing:
            self.dino_model.encoder.gradient_checkpointing = True

        # Precompute unconditional image embeddings (zeros, or the embeddings of an all-zero image)
        if self.cfg.zero_uncond_embeds:
            self.empty_image_embeds_clip = torch.zeros((self.cfg.n_views, 257, 1024)).detach()
            self.empty_image_embeds_dino = torch.zeros((self.cfg.n_views, 257, 1024)).detach()
            self.empty_image_embeds = torch.cat([self.empty_image_embeds_clip, self.empty_image_embeds_dino], dim=1)
        else:
            if self.cfg.encode_camera:
                self.empty_image_embeds_clip = self.encode_image_clip(torch.zeros(self.cfg.n_views, self.cfg.image_size_clip, self.cfg.image_size_clip, 3), self.cameras[:self.cfg.n_views]).detach()
                self.empty_image_embeds_dino = self.encode_image_dino(torch.zeros(self.cfg.n_views, self.cfg.image_size_dino, self.cfg.image_size_dino, 3), self.cameras[:self.cfg.n_views]).detach()
                self.empty_image_embeds = torch.cat([self.empty_image_embeds_clip, self.empty_image_embeds_dino], dim=1)
            else:
                self.empty_image_embeds_clip = self.encode_image_clip(torch.zeros(self.cfg.n_views, self.cfg.image_size_clip, self.cfg.image_size_clip, 3)).detach()
                self.empty_image_embeds_dino = self.encode_image_dino(torch.zeros(self.cfg.n_views, self.cfg.image_size_dino, self.cfg.image_size_dino, 3)).detach()
                self.empty_image_embeds = torch.cat([self.empty_image_embeds_clip, self.empty_image_embeds_dino], dim=1)

        # Freeze the CLIP parameters; optionally leave the camera-modulation norm layers trainable
        self.clip_model.eval()
        for k, p in self.clip_model.named_parameters():
            ks = k.split('.')
            if ('mod_norm1' in ks or 'mod_norm2' in ks) and not self.cfg.freeze_modulation_clip:
                p.requires_grad_(not self.cfg.freeze_modulation_clip)
            else:
                p.requires_grad_(False)

        # Freeze the DINO parameters; optionally leave the camera-modulation norm layers trainable
        self.dino_model.eval()
        for k, p in self.dino_model.named_parameters():
            ks = k.split('.')
            if ('mod_norm1' in ks or 'mod_norm2' in ks) and not self.cfg.freeze_modulation_dino:
                p.requires_grad_(not self.cfg.freeze_modulation_dino)
            else:
                p.requires_grad_(False)

        # Project DINOv2 features (768-dim) to the CLIP feature width (1024-dim)
        self.linear_proj = nn.Linear(768, 1024, bias=False)
        if self.cfg.linear_proj_init == "constant":
            nn.init.constant_(self.linear_proj.weight, 0)
        elif self.cfg.linear_proj_init == "xavier":
            nn.init.xavier_uniform_(self.linear_proj.weight)
        else:
            raise ValueError(f"Unknown linear_proj_init: {self.cfg.linear_proj_init}")

        if self.cfg.pretrained_model_name_or_path is not None:
            print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}")
            ckpt = torch.load(self.cfg.pretrained_model_name_or_path, map_location="cpu")['state_dict']
            pretrained_model_ckpt = {}
            for k, v in ckpt.items():
                # keep only the condition-encoder weights, stripping the 'condition.' prefix
                if k.startswith('condition.'):
                    pretrained_model_ckpt[k.replace('condition.', '')] = v
            self.load_state_dict(pretrained_model_ckpt, strict=False)

    def encode_image_clip(
        self,
        images: Iterable[Optional[ImageType]],
        cameras: Optional[torch.Tensor] = None,
        force_none_camera_embeds: bool = False,
        return_dict: bool = False,
        **kwargs,
    ) -> torch.FloatTensor:
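        """Encode images with the (optionally camera-conditioned) CLIP vision tower.

        Summary added for readability; shapes assume the default
        openai/clip-vit-large-patch14 backbone at 224x224. Tensor/ndarray inputs
        are treated as training batches in [0, 1] with shape (B*N, H, W, 3);
        anything else goes through CLIPImageProcessor. Returns the last hidden
        state of shape (B*N, 257, 1024), or a CLIPEmbedOutput when return_dict=True.
        """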
        camera_embeds = None
        if isinstance(images, (np.ndarray, torch.Tensor)):  # for training process
            assert images.min() >= 0.0 and images.max() <= 1.0, "The pixel values should be in the range of [0, 1]"
            do_rescale = False
            if self.cfg.encode_camera:
                assert cameras is not None, "The cameras should be provided"
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.transform_clip(images.permute(0, 3, 1, 2))
        else:  # for inference process
            do_rescale = True
            if self.cfg.encode_camera:
                if cameras is None:
                    bs = len(images) // self.cfg.n_views
                    cameras = self.cameras[:self.cfg.n_views].repeat(bs, 1, 1).to(self.clip_model.device)
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.image_preprocess_clip.preprocess(images, return_tensors='pt', do_rescale=do_rescale).pixel_values

        if force_none_camera_embeds:
            camera_embeds = None

        if pixel_values.ndim == 4:
            pixel_values = pixel_values.unsqueeze(1)
            if camera_embeds is not None:
                camera_embeds = camera_embeds.unsqueeze(1)

        if self.cfg.encode_camera and camera_embeds is not None:
            vision_outputs = self.clip_model.vision_model(
                pixel_values=rearrange(pixel_values.to(self.clip_model.device), "B N C H W -> (B N) C H W"),
                condition=rearrange(camera_embeds, "B N C -> (B N) C")
            )
        else:
            vision_outputs = self.clip_model.vision_model(
                pixel_values=rearrange(pixel_values.to(self.clip_model.device), "B N C H W -> (B N) C H W"),
            )

        if return_dict:
            # clip
            pooler_output = vision_outputs[1]  # pooled_output
            image_features = self.clip_model.visual_projection(pooler_output)
            clip_embeds = vision_outputs.last_hidden_state
            clip_embeds_dict = CLIPEmbedOutput(
                last_hidden_state=clip_embeds,
                pooler_output=pooler_output,
                embeds=image_features
            )
            return clip_embeds_dict
        else:
            return vision_outputs.last_hidden_state

    def encode_image_dino(
        self,
        images: Iterable[Optional[ImageType]],
        cameras: Optional[torch.Tensor] = None,
        force_none_camera_embeds: bool = False,
        return_dict: bool = False,
        **kwargs,
    ) -> torch.FloatTensor:
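        """Encode images with the (optionally camera-conditioned) DINOv2 tower.

        Summary added for readability; shapes assume the default
        facebook/dinov2-base backbone at 224x224. Input handling mirrors
        encode_image_clip. Returns the last hidden state of shape (B*N, 257, 768),
        or a DINOEmbedOutput when return_dict=True.
        """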
        camera_embeds = None
        if isinstance(images, (np.ndarray, torch.Tensor)):  # for training process
            assert images.min() >= 0.0 and images.max() <= 1.0, "The pixel values should be in the range of [0, 1]"
            do_rescale = False
            if self.cfg.encode_camera:
                assert cameras is not None, "The cameras should be provided"
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.transform_dino(images.permute(0, 3, 1, 2))
        else:  # for inference process
            do_rescale = True
            if self.cfg.encode_camera:
                if cameras is None:
                    bs = len(images) // self.cfg.n_views
                    cameras = self.cameras[:self.cfg.n_views].repeat(bs, 1, 1).to(self.dino_model.device)
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.image_preprocess_dino.preprocess(images, return_tensors='pt', do_rescale=do_rescale).pixel_values

        if force_none_camera_embeds:
            camera_embeds = None

        if pixel_values.ndim == 4:
            pixel_values = pixel_values.unsqueeze(1)
            if camera_embeds is not None:
                camera_embeds = camera_embeds.unsqueeze(1)

        if self.cfg.encode_camera and camera_embeds is not None:
            vision_outputs = self.dino_model(
                rearrange(pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W"),
                condition=rearrange(camera_embeds, "B N C -> (B N) C"),
            )
        else:
            vision_outputs = self.dino_model(
                rearrange(pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W"),
            )

        if return_dict:
            # dino
            dino_embeds_dict = DINOEmbedOutput(
                last_hidden_state=vision_outputs.last_hidden_state,
                pooler_output=vision_outputs.pooler_output,
            )
            return dino_embeds_dict
        else:
            return vision_outputs.last_hidden_state

    def encode_image(
        self,
        images: Iterable[Optional[ImageType]],
        cameras: Optional[torch.Tensor] = None,
        force_none_camera_embeds: bool = False,
        return_dict: bool = False,
        **kwargs,
    ) -> torch.FloatTensor:
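        """Fuse CLIP and DINOv2 image tokens into a single conditioning sequence.

        Summary added for readability: DINO tokens are projected from 768 to 1024
        dims by self.linear_proj, then concatenated with the CLIP tokens along the
        token axis, giving (B*N, 257 + 257, 1024) with the default backbones.
        """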
        clip_embeds = self.encode_image_clip(images, cameras)
        dino_embeds = self.encode_image_dino(images, cameras)
        dino_embeds = self.linear_proj(dino_embeds)
        visual_embeds = torch.cat([clip_embeds, dino_embeds], dim=1)
        return visual_embeds
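

# The sketch below is not part of the embedder; it is a minimal, self-contained
# illustration of the fusion performed by CondEmbedder.encode_image, using random
# tensors in place of the CLIP and DINOv2 towers. The token count (257) and widths
# (1024 for CLIP ViT-L/14, 768 for DINOv2-base) match the defaults assumed above.
if __name__ == "__main__":
    B_N = 4                                    # batch_size * n_views
    clip_tokens = torch.randn(B_N, 257, 1024)  # stand-in for CLIP last_hidden_state
    dino_tokens = torch.randn(B_N, 257, 768)   # stand-in for DINOv2 last_hidden_state

    proj = nn.Linear(768, 1024, bias=False)    # mirrors self.linear_proj
    fused = torch.cat([clip_tokens, proj(dino_tokens)], dim=1)
    print(fused.shape)                         # torch.Size([4, 514, 1024])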