import os
import argparse
import pdb
import time
import copy

import cv2
import numpy as np
from PIL import Image
import easyocr

from caption_anything.captioner import build_captioner, BaseCaptioner
from caption_anything.segmenter import build_segmenter, build_segmenter_densecap
from caption_anything.text_refiner import build_text_refiner
from caption_anything.utils.utils import prepare_segmenter, seg_model_map, load_image, get_image_shape
from caption_anything.utils.utils import mask_painter_foreground_all, mask_painter, xywh_to_x1y1x2y2, image_resize
from caption_anything.utils.densecap_painter import draw_bbox


class CaptionAnything:
    """Caption-Anything pipeline: a promptable segmenter, an image captioner, an EasyOCR reader,
    and an optional LLM-based text refiner, combined to caption user-selected regions or whole images."""

    def __init__(self, args, api_key="", captioner=None, segmenter=None, ocr_reader=None, text_refiner=None):
        self.args = args
        self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
        self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
        self.segmenter_densecap = build_segmenter_densecap(args.segmenter, args.device, args, model=self.segmenter.model)
        self.ocr_lang = ["ch_tra", "en"]
        self.ocr_reader = ocr_reader if ocr_reader is not None else easyocr.Reader(self.ocr_lang)

        self.text_refiner = None
        if not args.disable_gpt:
            if text_refiner is not None:
                self.text_refiner = text_refiner
            elif api_key != "":
                self.init_refiner(api_key)
        self.require_caption_prompt = args.captioner == 'blip2'
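
    # The properties below proxy the segmenter's cached state (image embedding and the
    # original/input image sizes) so callers can save the precomputed features and restore
    # them later through setup() without re-encoding the image.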
    @property
    def image_embedding(self):
        return self.segmenter.image_embedding

    @image_embedding.setter
    def image_embedding(self, image_embedding):
        self.segmenter.image_embedding = image_embedding

    @property
    def original_size(self):
        return self.segmenter.predictor.original_size

    @original_size.setter
    def original_size(self, original_size):
        self.segmenter.predictor.original_size = original_size

    @property
    def input_size(self):
        return self.segmenter.predictor.input_size

    @input_size.setter
    def input_size(self, input_size):
        self.segmenter.predictor.input_size = input_size

    def setup(self, image_embedding, original_size, input_size, is_image_set):
        self.image_embedding = image_embedding
        self.original_size = original_size
        self.input_size = input_size
        self.segmenter.predictor.is_image_set = is_image_set

    def init_refiner(self, api_key):
        try:
            self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
            self.text_refiner.llm('hi')  # send a test query to validate the API key
        except Exception:
            self.text_refiner = None
            print('OpenAI GPT is not available')
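
    # inference(): segment the image according to the prompt (e.g. click or everything),
    # caption each resulting mask, and optionally refine the captions with the text refiner.
    # Returns a list of dicts, one per mask.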
    def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False, verbose=False, is_densecap=False, args=None):
        args = {} if args is None else args  # avoid a shared mutable default argument
        # segment with prompt
        print("CA prompt: ", prompt, "CA controls", controls)
        is_seg_everything = 'everything' in prompt['prompt_type']

        # fall back to the global CLI options for anything not overridden per call
        args['seg_crop_mode'] = args.get('seg_crop_mode', self.args.seg_crop_mode)
        args['clip_filter'] = args.get('clip_filter', self.args.clip_filter)
        args['disable_regular_box'] = args.get('disable_regular_box', self.args.disable_regular_box)
        args['context_captions'] = args.get('context_captions', self.args.context_captions)
        args['enable_reduce_tokens'] = args.get('enable_reduce_tokens', self.args.enable_reduce_tokens)
        args['enable_morphologyex'] = args.get('enable_morphologyex', self.args.enable_morphologyex)
        args['topN'] = args.get('topN', 10) if is_seg_everything else 1
        args['min_mask_area'] = args.get('min_mask_area', 0)

        if not is_densecap:
            seg_results = self.segmenter.inference(image, prompt)
        else:
            seg_results = self.segmenter_densecap.inference(image, prompt)
        seg_masks, seg_bbox, seg_area = seg_results if is_seg_everything else (seg_results, None, None)

        if args['topN'] > 1:  # sort by area
            samples = list(zip(*[seg_masks, seg_bbox, seg_area]))
            # top_samples = sorted(samples, key=lambda x: x[2], reverse=True)
            # seg_masks, seg_bbox, seg_area = list(zip(*top_samples))
            samples = list(filter(lambda x: x[2] > args['min_mask_area'], samples))
            samples = samples[:args['topN']]
            seg_masks, seg_bbox, seg_area = list(zip(*samples))

        out_list = []
        for i, seg_mask in enumerate(seg_masks):
            if args['enable_morphologyex']:
                # morphological opening then closing to remove speckles and fill small holes in the mask
                seg_mask = 255 * seg_mask.astype(np.uint8)
                seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
                seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel=np.ones((6, 6), np.uint8))
                seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel=np.ones((6, 6), np.uint8))
                seg_mask = seg_mask[:, :, 0] > 0

            seg_mask_img = Image.fromarray((seg_mask * 255).astype(np.uint8))
            mask_save_path = None
            if verbose:
                mask_save_path = f'result/mask_{time.time()}.png'
                if not os.path.exists(os.path.dirname(mask_save_path)):
                    os.makedirs(os.path.dirname(mask_save_path))
                if seg_mask_img.mode != 'RGB':
                    seg_mask_img = seg_mask_img.convert('RGB')
                seg_mask_img.save(mask_save_path)
                print('seg_mask path: ', mask_save_path)
                print("seg_mask.shape: ", seg_mask.shape)

            # captioning with mask
            if args['enable_reduce_tokens']:
                result = self.captioner.inference_with_reduced_tokens(image, seg_mask,
                                                                      crop_mode=args['seg_crop_mode'],
                                                                      filter=args['clip_filter'],
                                                                      disable_regular_box=args['disable_regular_box'],
                                                                      verbose=verbose,
                                                                      caption_args=args)
            else:
                result = self.captioner.inference_seg(image, seg_mask,
                                                      crop_mode=args['seg_crop_mode'],
                                                      filter=args['clip_filter'],
                                                      disable_regular_box=args['disable_regular_box'],
                                                      verbose=verbose,
                                                      caption_args=args)
            caption = result.get('caption', None)
            crop_save_path = result.get('crop_save_path', None)

            # refining with TextRefiner
            context_captions = []
            if args['context_captions']:
                context_captions.append(self.captioner.inference(image)['caption'])
            if not disable_gpt and self.text_refiner is not None:
                refined_caption = self.text_refiner.inference(query=caption, controls=controls,
                                                              context=context_captions, enable_wiki=enable_wiki)
            else:
                refined_caption = {'raw_caption': caption}

            out = {'generated_captions': refined_caption,
                   'crop_save_path': crop_save_path,
                   'mask_save_path': mask_save_path,
                   'mask': seg_mask_img,
                   'bbox': seg_bbox[i] if seg_bbox is not None else None,
                   'area': seg_area[i] if seg_area is not None else None,
                   'context_captions': context_captions,
                   'ppl_score': result.get('ppl_score', -100.),
                   'clip_score': result.get('clip_score', 0.)}
            out_list.append(out)
        return out_list
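
    # parse_dense_caption(): run "segment everything", caption every kept mask, filter the
    # results by length-normalized ppl score and CLIP score, and pack them into a compact
    # "(caption: X, Y, Width, Height)" prompt string for the summarizing LLM.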
    def parse_dense_caption(self, image, topN=10, reference_caption=[], verbose=False):
        width, height = get_image_shape(image)
        prompt = {'prompt_type': ['everything']}
        densecap_args = {
            'return_ppl': True,
            'clip_filter': True,
            'reference_caption': reference_caption,
            'text_prompt': "",  # 'Question: what does the image show? Answer:'
            'seg_crop_mode': 'w_bg',
            # 'text_prompt': "",
            # 'seg_crop_mode': 'wo_bg',
            'disable_regular_box': False,
            'topN': topN,
            'min_ppl_score': -1.8,
            'min_clip_score': 0.30,
            'min_mask_area': 2500,
        }
        dense_captions = self.inference(image, prompt,
                                        controls=None,
                                        disable_gpt=True,
                                        verbose=verbose,
                                        is_densecap=True,
                                        args=densecap_args)
        print('Process Dense Captioning: \n', dense_captions)
        dense_captions = list(filter(lambda x: x['ppl_score'] / (1 + len(x['generated_captions']['raw_caption'].split())) >= densecap_args['min_ppl_score'], dense_captions))
        dense_captions = list(filter(lambda x: x['clip_score'] >= densecap_args['min_clip_score'], dense_captions))

        dense_cap_prompt = []
        for cap in dense_captions:
            x, y, w, h = cap['bbox']
            cx, cy = x + w / 2, y + h / 2
            dense_cap_prompt.append("({}: X:{:.0f}, Y:{:.0f}, Width:{:.0f}, Height:{:.0f})".format(
                cap['generated_captions']['raw_caption'], cx, cy, w, h))

        if verbose:
            all_masks = [np.array(item['mask'].convert('P')) for item in dense_captions]
            new_image = mask_painter_foreground_all(np.array(image), all_masks, background_alpha=0.4)
            save_path = 'result/dense_caption_mask.png'
            Image.fromarray(new_image).save(save_path)
            print(f'Dense captioning mask saved in {save_path}')
            vis_path = 'result/dense_caption_vis_{}.png'.format(time.time())
            dense_cap_painter_input = [{'bbox': xywh_to_x1y1x2y2(cap['bbox']),
                                        'caption': cap['generated_captions']['raw_caption']} for cap in dense_captions]
            draw_bbox(load_image(image, return_type='numpy'), vis_path, dense_cap_painter_input, show_caption=True)
            print(f'Dense captioning visualization saved in {vis_path}')
        return ','.join(dense_cap_prompt)
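
    # parse_ocr(): run EasyOCR, keep detections above the confidence threshold, and report
    # each text snippet together with the centroid of its bounding quadrilateral.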
    def parse_ocr(self, image, thres=0.2):
        width, height = get_image_shape(image)
        image = load_image(image, return_type='numpy')
        bounds = self.ocr_reader.readtext(image)
        bounds = [bound for bound in bounds if bound[2] > thres]
        print('Process OCR Text:\n', bounds)

        ocr_prompt = []
        for box, text, conf in bounds:
            p0, p1, p2, p3 = box
            ocr_prompt.append('("{}": X:{:.0f}, Y:{:.0f})'.format(
                text, (p0[0] + p1[0] + p2[0] + p3[0]) / 4, (p0[1] + p1[1] + p2[1] + p3[1]) / 4))
        ocr_prompt = '\n'.join(ocr_prompt)
        # ocr_prompt = self.text_refiner.llm(f'The image have some scene texts with their locations: {ocr_prompt}. Please group these individual words into one or several phrase based on their relative positions (only give me your answer, do not show explanination)').strip()
        # ocr_prefix1 = f'The image have some scene texts with their locations: {ocr_prompt}. Please group these individual words into one or several phrase based on their relative positions (only give me your answer, do not show explanination)'
        # ocr_prefix2 = f'Please group these individual words into 1-3 phrases, given scene texts with their locations: {ocr_prompt}. You return is one or several strings and infer their locations. (only give me your answer like ("man working", X: value, Y: value), do not show explanination)'
        # ocr_prefix4 = f'summarize the individual scene text words detected by OCR tools into a fluent sentence based on their positions and distances. You should strictly describe all of the given scene text words. Do not miss any given word. Do not create non-exist words. Do not appear numeric positions. The individual words are given:\n{ocr_prompt}\n'
        # ocr_prefix3 = f'combine the individual scene text words detected by OCR tools into one/several fluent phrases/sentences based on their positions and distances. You should strictly copy or correct all of the given scene text words. Do not miss any given word. Do not create non-exist words. The response is several strings seperate with their location (X, Y), each of which represents a phrase. The individual words are given:\n{ocr_prompt}\n'
        # response = self.text_refiner.llm(ocr_prefix3).strip() if len(ocr_prompt) else ""
        return ocr_prompt
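
    # inference_cap_everything(): whole-image "caption everything" pipeline. It resizes the
    # image, collects a global caption, dense region captions, and OCR text, then asks the
    # text refiner's LLM to summarize everything into a single paragraph.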
    def inference_cap_everything(self, image, verbose=False):
        image = load_image(image, return_type='pil')
        image = image_resize(image, res=1024)
        width, height = get_image_shape(image)
        other_args = {'text_prompt': ""} if self.require_caption_prompt else {}
        img_caption = self.captioner.inference(image, filter=False, args=other_args)['caption']
        dense_caption_prompt = self.parse_dense_caption(image, topN=10, verbose=verbose, reference_caption=[])
        scene_text_prompt = self.parse_ocr(image, thres=0.2)
        # scene_text_prompt = "N/A"

        # the summarize_prompt is modified from https://github.com/JialianW/GRiT and https://github.com/showlab/Image2Paragraph
        summarize_prompt = (
            "Imagine you are a blind but intelligent image captioner. You should generate a descriptive, coherent and "
            "human-like paragraph based on the given information (a,b,c,d) instead of imagination:\n"
            "a) Image Resolution: {image_size}\n"
            "b) Image Caption: {image_caption}\n"
            "c) Dense Caption: {dense_caption}\n"
            "d) Scene Text: {scene_text}\n"
            "There are some rules for your response: Show objects with their attributes (e.g. position, color, size, shape, texture).\n"
            "Primarily describe common objects with large size.\n"
            "Provide context of the image.\n"
            "Show relative positions between objects.\n"
            "Use less than 6 sentences.\n"
            "Do not include any numbers.\n"
            "Do not describe any individual letter.\n"
            "Do not show the image resolution.\n"
            "Ignore the white background."
        )
        prompt = summarize_prompt.format(**{
            "image_size": "width {} height {}".format(width, height),
            "image_caption": img_caption,
            "dense_caption": dense_caption_prompt,
            "scene_text": scene_text_prompt})
        print(f'caption everything prompt: {prompt}')
        response = self.text_refiner.llm(prompt).strip()
        # chinese_response = self.text_refiner.llm('Translate it into Chinese: {}'.format(response)).strip()
        return response
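

# Quick smoke test: caption and OCR every image under a test folder (requires the
# OPENAI_API_KEY environment variable for the text refiner).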
if __name__ == "__main__":
    from caption_anything.utils.parser import parse_augment

    args = parse_augment()
    image_path = 'result/wt/memes/87226084.jpg'
    image = Image.open(image_path)
    prompts = [
        {
            "prompt_type": ["click"],
            "input_point": [[500, 300], [200, 500]],
            "input_label": [1, 0],
            "multimask_output": "True",
        },
        # {
        #     "prompt_type": ["click"],
        #     "input_point": [[300, 800]],
        #     "input_label": [1],
        #     "multimask_output": "True",
        # }
    ]
    controls = {
        "length": "30",
        "sentiment": "positive",
        # "imagination": "True",
        "imagination": "False",
        "language": "English",
    }

    model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
    img_dir = 'test_images/memes'
    for image_file in os.listdir(img_dir):
        image_path = os.path.join(img_dir, image_file)
        print('image_path:', image_path)
        paragraph = model.inference_cap_everything(image_path, verbose=True)
        print('Caption Everything:\n', paragraph)
        ocr = model.parse_ocr(image_path)
        print('OCR', ocr)