Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import AutoModelForVision2Seq, AutoProcessor | |
| from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration | |
| from transformers.tools import PipelineTool | |
| from transformers.tools.base import get_default_device | |
| from transformers.utils import requires_backends | |
class InstructBLIPImageQuestionAnsweringTool(PipelineTool):
    """Answer a natural-language question about an image with an InstructBLIP checkpoint.

    By default the checkpoint is loaded with ``load_in_4bit=True`` and
    ``torch_dtype=torch.float16``. Both are defaults only: entries in ``model_kwargs``
    take precedence (e.g. ``{"torch_dtype": torch.float32, "load_in_4bit": False}``).
    Supplying ``load_in_8bit`` or ``quantization_config`` there replaces the 4-bit default.
    """

    # Other checkpoints previously tried with this tool:
    #   "Salesforce/blip2-opt-2.7b"
    #   "Salesforce/instructblip-flan-t5-xl"
    #   "Salesforce/instructblip-vicuna-7b"
    default_checkpoint = "Salesforce/instructblip-vicuna-13b"
    description = (
        "This is a tool that answers a question about an image. It takes an input named `image` which should be the "
        "image containing the information, as well as a `question` which should be the question in English. It "
        "returns a text that is the answer to the question."
    )
    name = "image_qa"
    pre_processor_class = AutoProcessor
    model_class = AutoModelForVision2Seq

    inputs = ["image", "text"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        # Fail early with a clear message if the image-processing backend is missing.
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    def setup(self):
        """
        Instantiates the `pre_processor`, `model` and `post_processor` if necessary.

        The model is loaded with the kwargs from `_model_load_kwargs` (4-bit / float16
        defaults merged with `model_kwargs`). `self.device` is resolved from the model's
        `hf_device_map` when a `device_map` was requested, otherwise from
        `get_default_device()`. The model is moved to that device only when it has not
        already been placed by a device map or by quantization.
        """
        if isinstance(self.pre_processor, str):
            self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)

        if isinstance(self.model, str):
            self.model = self.model_class.from_pretrained(
                self.model, **self._model_load_kwargs(), **self.hub_kwargs
            )

        if self.post_processor is None:
            self.post_processor = self.pre_processor
        elif isinstance(self.post_processor, str):
            self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)

        if self.device is None:
            if self.device_map is not None:
                # Use the device holding the first dispatched module.
                self.device = list(self.model.hf_device_map.values())[0]
            else:
                self.device = get_default_device()

        if self._needs_manual_placement():
            self.model.to(self.device)

        self.is_initialized = True

    def _model_load_kwargs(self):
        """Return `from_pretrained` kwargs: 4-bit/float16 defaults overridden by `model_kwargs`."""
        load_kwargs = dict(self.model_kwargs or {})
        load_kwargs.setdefault("torch_dtype", torch.float16)
        # Only add the 4-bit default when the caller has not chosen a quantization scheme,
        # so their choice neither collides with nor is silently replaced by ours.
        quantization_keys = ("load_in_4bit", "load_in_8bit", "quantization_config")
        if not any(key in load_kwargs for key in quantization_keys):
            load_kwargs["load_in_4bit"] = True
        return load_kwargs

    def _needs_manual_placement(self):
        """Return True if the loaded model still has to be moved to `self.device` by hand."""
        # A model dispatched through a device map already has its placement.
        if self.device_map is not None or getattr(self.model, "hf_device_map", None) is not None:
            return False
        # bitsandbytes-quantized models are placed at load time and do not support `.to()`.
        return not (
            getattr(self.model, "is_loaded_in_4bit", False)
            or getattr(self.model, "is_loaded_in_8bit", False)
        )

    def encode(self, image, question: str):
        """Preprocess ``image`` and ``question`` into model inputs on ``self.device``.

        Floating-point inputs are cast to the model's dtype so they match its weights.
        The processor output's ``.to`` is expected to leave integer token ids uncast.
        """
        # Before `setup()` runs, `self.model` may still be a checkpoint name (a str, which
        # has no `dtype`), so fall back to float16, the default dtype `setup()` loads with.
        dtype = getattr(self.model, "dtype", torch.float16)
        encoded = self.pre_processor(images=image, text=question, return_tensors="pt")
        return encoded.to(device=self.device, dtype=dtype)

    def forward(self, inputs):
        """Run `self.model.generate` with 5 beams and up to 256 new tokens; return token ids."""
        # NOTE(review): `top_p` and `temperature` only influence sampling-based decoding.
        # Unless sampling is enabled (e.g. by the checkpoint's generation config), they are
        # ignored here. They are kept so existing behavior is unchanged.
        return self.model.generate(
            **inputs,
            num_beams=5,
            max_new_tokens=256,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.5,
            length_penalty=1.0,
            temperature=0.7,
        )

    def decode(self, outputs):
        """Decode the first generated sequence, dropping special tokens and outer whitespace."""
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()