from io import BytesIO
from typing import Any, Dict, Optional

import PIL.Image
import torch
from sentence_transformers.models import Transformer as BaseTransformer
from transformers import SiglipImageProcessor


class MultiModalTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}

        # Image processor that turns PIL images into pixel_values for the vision tower.
        self.processor = SiglipImageProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )
    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        trans_features = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
        }
        if "pixel_values" in features:
            # Cast images to the backbone's dtype (e.g. float16/bfloat16) before the forward pass.
            trans_features["pixel_values"] = features["pixel_values"].to(
                self.auto_model.dtype
            )

        sentence_embedding = self.auto_model(**trans_features, **kwargs)[
            "sentence_embedding"
        ]
        features.update({"sentence_embedding": sentence_embedding})
        return features
    def tokenize(self, texts: list[dict] | list[str]) -> dict[str, torch.Tensor]:
        # Special tokens that mark the image placeholder span inside the text.
        img_start_token = "<|jasper_img_start|>"
        img_token = "<|jasper_img_token|>"
        img_end_token = "<|jasper_img_end|>"
        num_img_tokens = 300

        def process_text_item(item):
            # Plain strings are text-only inputs; lists are sequences of text/image parts.
            if isinstance(item, str):
                return item, []
            text, images = "", []
            for sub_item in item:
                if sub_item["type"] == "text":
                    text += sub_item["content"]
                elif sub_item["type"] == "image_bytes":
                    text += img_start_token + img_token * num_img_tokens + img_end_token
                    images.append(
                        PIL.Image.open(BytesIO(sub_item["content"])).convert("RGB")
                    )
                elif sub_item["type"] == "image_path":
                    text += img_start_token + img_token * num_img_tokens + img_end_token
                    images.append(PIL.Image.open(sub_item["content"]).convert("RGB"))
                else:
                    raise ValueError(f"unknown data type {sub_item['type']}")
            return text, images

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.extend(images)

        # Tokenize the texts (image spans are already expanded into placeholder tokens).
        ipt = self.tokenizer(
            all_texts,
            padding="longest",
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )
        if all_images:
            # Preprocess all images in one batch; their order matches the image spans in the texts.
            ipt["pixel_values"] = self.processor(
                images=all_images, return_tensors="pt"
            )["pixel_values"]
        return ipt
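

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the checkpoint path and the
    # image file below are placeholders, and the backbone is assumed to expose a
    # "sentence_embedding" output plus a SigLIP image-processor config, as forward() expects.
    model = MultiModalTransformer("path/to/multimodal-embedding-checkpoint")
    model.auto_model.eval()

    # tokenize() accepts plain strings or lists of {"type": ..., "content": ...} parts,
    # where images may be given as raw bytes ("image_bytes") or file paths ("image_path").
    inputs = [
        "an aerial photo of a coastal city",
        [
            {"type": "text", "content": "Describe this picture: "},
            {"type": "image_path", "content": "example.jpg"},
        ],
    ]

    features = model.tokenize(inputs)
    with torch.no_grad():
        outputs = model(features)
    print(outputs["sentence_embedding"].shape)  # one embedding per input item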