from models.med import BertConfig, BertModel
from transformers import BertTokenizer
import torch
from torch import nn
import torch.nn.functional as F

from models.blip import create_vit, init_tokenizer, load_checkpoint

class BLIP_ITM(nn.Module):
    def __init__(self,
                 med_config='configs/med_config.json',
                 image_size=384,
                 vit='base',
                 vit_grad_ckpt=False,
                 vit_ckpt_layer=0,
                 embed_dim=256,
                 ):
        """
        Args:
            med_config (str): path to the mixture-of-encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of the vision transformer ('base' or 'large')
        """
        super().__init__()

        # ViT image encoder; vision_width is the dimensionality of its patch embeddings
        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()

        # Text encoder: a BERT model whose cross-attention width matches the vision encoder
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        # Projection heads used for image-text contrastive (ITC) similarity
        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        # Binary classifier for image-text matching (ITM)
        self.itm_head = nn.Linear(text_width, 2)

    def forward(self, image, caption, match_head='itm'):
        # Encode the image; image_atts marks every patch token as valid
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
                              return_tensors="pt").to(image.device)

        if match_head == 'itm':
            # Image-text matching: run the text encoder with cross-attention over the
            # image features and classify the [CLS] state as match / no-match
            output = self.text_encoder(text.input_ids,
                                       attention_mask=text.attention_mask,
                                       encoder_hidden_states=image_embeds,
                                       encoder_attention_mask=image_atts,
                                       return_dict=True,
                                       )
            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])
            return itm_output

        elif match_head == 'itc':
            # Image-text contrastive: encode the text alone, project both [CLS]
            # embeddings into a shared space, and return their cosine similarity
            text_output = self.text_encoder(text.input_ids, attention_mask=text.attention_mask,
                                            return_dict=True, mode='text')
            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
            text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)

            sim = image_feat @ text_feat.t()
            return sim

def blip_itm(pretrained='', **kwargs):
    # Factory helper: build a BLIP_ITM model and optionally load pretrained weights
    model = BLIP_ITM(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        assert len(msg.missing_keys) == 0
    return model
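

# Usage sketch (illustrative addition, not part of the original file): shows how the
# 'itm' and 'itc' heads of BLIP_ITM might be queried for a single image-caption pair.
# The checkpoint path 'model_base_retrieval_coco.pth', the image file 'demo.jpg', and
# the caption are placeholder assumptions; the normalization statistics below are the
# ones commonly used in BLIP demos, so verify them against your own preprocessing.
if __name__ == '__main__':
    from PIL import Image
    from torchvision import transforms

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_size = 384

    # Resize and normalize the input image to the resolution the model expects
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size),
                          interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])

    # Placeholder checkpoint path; any BLIP retrieval checkpoint compatible with
    # load_checkpoint (local file or URL) should work here
    model = blip_itm(pretrained='model_base_retrieval_coco.pth',
                     image_size=image_size, vit='base').to(device).eval()

    image = transform(Image.open('demo.jpg').convert('RGB')).unsqueeze(0).to(device)
    caption = 'a woman sitting on the beach with her dog'

    with torch.no_grad():
        # ITM head: 2-way logits; softmax index 1 is the probability of a match
        itm_logits = model(image, caption, match_head='itm')
        itm_score = torch.softmax(itm_logits, dim=1)[:, 1]
        # ITC head: cosine similarity between projected image and text embeddings
        itc_score = model(image, caption, match_head='itc')

    print(f'ITM match probability: {itm_score.item():.4f}')
    print(f'ITC cosine similarity: {itc_score.item():.4f}')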