Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
| SPDX-License-Identifier: MIT | |
| """ | |
| import numpy as np | |
| import torch | |
| from PIL import ImageOps | |
| from torchvision import transforms | |
| from torchvision.transforms.functional import resize | |
# Per-channel mean and standard deviation handed to transforms.Normalize in
# DolphinProcessor.__init__. The names indicate the usual ImageNet statistics;
# the three entries presumably follow RGB channel order -- confirm with callers.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
class DolphinProcessor:
    """Turn prompts and images into tensors for inference.

    The processor holds a tokenizer for prompts and a ToTensor + Normalize
    transform for images. Both ``process_*_for_inference`` methods return
    tensors with a leading batch dimension of size 1.
    """

    def __init__(
        self,
        dp_config,
        tokenizer,
        **kwargs,
    ) -> None:
        """Store the tokenizer and read preprocessing settings.

        Args:
            dp_config: Mapping-like object queried with ``.get`` for
                ``prefix_answer_space_flag`` and ``suffix_prompt_space_flag``
                (both default to ``True``).
            tokenizer: Object exposing ``encode(text, add_special_tokens=...)``.
                Its ``_prompt_end_token`` attribute, when present, becomes
                ``answer_start_token``.
            **kwargs: Only ``transform_args`` is read. It may contain
                ``max_length`` (default 2048) and ``input_size`` (default
                ``[896, 896]``, given as ``[height, width]`` or a single int
                for a square input).
        """
        self.tokenizer = tokenizer

        # `or {}` also covers callers that pass transform_args=None explicitly.
        transform_args = kwargs.get("transform_args") or {}
        self.max_length = transform_args.get("max_length", 2048)
        self.input_size = transform_args.get("input_size", [896, 896])  # height, width
        if isinstance(self.input_size, int):
            self.input_size = [self.input_size, self.input_size]

        try:
            self.answer_start_token = self.tokenizer._prompt_end_token
        except AttributeError:
            print('No answer_start_token found, use "" instead')
            self.answer_start_token = ""

        self.prefix_answer_space_flag = dp_config.get("prefix_answer_space_flag", True)
        self.suffix_prompt_space_flag = dp_config.get("suffix_prompt_space_flag", True)

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
            ]
        )

    def process_prompt_for_inference(self, prompt):
        """Tokenize ``prompt`` into an int32 tensor of shape ``(1, num_tokens)``.

        Every ``"<image>\\n"`` placeholder is removed and ``"<s>"`` is
        prepended unless the prompt already starts with it. The tokenizer is
        called with ``add_special_tokens=False``.

        Args:
            prompt: Prompt text.

        Returns:
            A ``torch.int32`` tensor holding the token ids, with a leading
            batch dimension of size 1.
        """
        prompt = prompt.replace("<image>\n", "")
        if not prompt.startswith("<s>"):
            prompt = "<s>" + prompt

        token_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        # Avoid np.hstack(..., dtype=...): that keyword only exists in
        # NumPy >= 1.24. asarray + reshape yields the same flat int32 array.
        flat_ids = np.asarray(token_ids, dtype=np.int32).reshape(-1)
        return torch.from_numpy(flat_ids).unsqueeze(0)

    def process_image_for_inference(self, image, return_img_size=False):
        """Resize, pad and normalize ``image`` to the configured input size.

        Steps, in order:
          1. ``resize`` with ``min(self.input_size)`` (with an int size,
             torchvision matches the image's shorter edge to it).
          2. ``thumbnail`` so the result fits inside ``(width, height)`` of
             ``self.input_size``.
          3. Pad the remaining space on all sides, keeping the content
             centred; an odd remainder goes to the right/bottom.
          4. Apply ``self.transform`` and add a batch dimension.

        Args:
            image: Input image; ``.thumbnail``, ``.size``, ``.width`` and
                ``.height`` are used, so presumably a ``PIL.Image.Image``.
            return_img_size: When true, also return the image size measured
                after resizing but before padding.

        Returns:
            The transformed tensor with a leading batch dimension of size 1,
            or ``(tensor, (width, height))`` when ``return_img_size`` is true.
        """
        target_h, target_w = self.input_size[0], self.input_size[1]

        image = resize(image, min(self.input_size))
        image.thumbnail((target_w, target_h))
        # Size of the resized content, captured before padding is added.
        origin_w, origin_h = image.size

        delta_width = target_w - image.width
        delta_height = target_h - image.height
        pad_width = delta_width // 2
        pad_height = delta_height // 2
        # ImageOps.expand border order: left, top, right, bottom.
        padding = (
            pad_width,
            pad_height,
            delta_width - pad_width,
            delta_height - pad_height,
        )
        image = ImageOps.expand(image, padding)

        batch = self.transform(image).unsqueeze(0)
        if return_img_size:
            return batch, (origin_w, origin_h)
        return batch