Commit 108f2df · ttengwang committed · Parent(s): b7e072a

share ocr_reader to accelerate inference
Files changed:
- app.py (+11 -3)
- caption_anything/captioner/blip2.py (+2 -2)
- caption_anything/model.py (+8 -5)
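The point of the commit: constructing an easyocr.Reader loads the OCR detection and recognition models, so building a fresh reader inside every per-session CaptionAnything repeats that work on each request. The change builds one reader at startup and passes it into each session's model. A minimal sketch of that pattern, where SessionModel is an illustrative stand-in for CaptionAnything rather than the real class:

import time
import easyocr

# Expensive step: constructing a Reader loads the OCR model weights,
# so it should happen once per process, not once per request.
t0 = time.time()
shared_ocr_reader = easyocr.Reader(["ch_tra", "en"])
print("easyocr init time:", time.time() - t0)

class SessionModel:
    """Illustrative stand-in for CaptionAnything (not the real class)."""
    def __init__(self, ocr_reader=None):
        # Reuse the shared reader when one is supplied; only fall back to a
        # private (slow) construction when it is not.
        self.ocr_reader = ocr_reader if ocr_reader is not None else easyocr.Reader(["ch_tra", "en"])

# Each new session reuses the same reader, so only startup pays the init cost.
model_a = SessionModel(ocr_reader=shared_ocr_reader)
model_b = SessionModel(ocr_reader=shared_ocr_reader)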
app.py CHANGED

@@ -17,7 +17,7 @@ from caption_anything.text_refiner import build_text_refiner
 from caption_anything.segmenter import build_segmenter
 from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
 from segment_anything import sam_model_registry
-
+import easyocr
 
 args = parse_augment()
 args.segmenter = "huge"
@@ -30,6 +30,8 @@ else:
 
 shared_captioner = build_captioner(args.captioner, args.device, args)
 shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
+ocr_lang = ["ch_tra", "en"]
+shared_ocr_reader = easyocr.Reader(ocr_lang)
 tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
 shared_chatbot_tools = build_chatbot_tools(tools_dict)
 
@@ -57,13 +59,13 @@ class ImageSketcher(gr.Image):
         return super().preprocess(x)
 
 
-def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None,
+def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, ocr_reader=None, text_refiner=None,
                                        session_id=None):
    segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
    captioner = captioner
    if session_id is not None:
        print('Init caption anything for session {}'.format(session_id))
-    return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)
+    return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, ocr_reader=ocr_reader, text_refiner=text_refiner)
 
 
 def init_openai_api_key(api_key=""):
@@ -146,6 +148,7 @@ def upload_callback(image_input, state, visual_chatgpt=None):
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         session_id=iface.app_id
     )
     model.segmenter.set_image(image_input)
@@ -154,6 +157,7 @@ def upload_callback(image_input, state, visual_chatgpt=None):
     input_size = model.input_size
 
     if visual_chatgpt is not None:
+        print('upload_callback: add caption to chatGPT memory')
         new_image_path = get_new_image_name('chat_image', func_name='upload')
         image_input.save(new_image_path)
         visual_chatgpt.current_image = new_image_path
@@ -192,6 +196,7 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         text_refiner=text_refiner,
         session_id=iface.app_id
     )
@@ -213,6 +218,7 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
     x, y = input_points[-1]
 
     if visual_chatgpt is not None:
+        print('inference_click: add caption to chatGPT memory')
         new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
         Image.open(out["crop_save_path"]).save(new_crop_save_path)
         point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
@@ -273,6 +279,7 @@ def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuali
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         text_refiner=text_refiner,
         session_id=iface.app_id
     )
@@ -325,6 +332,7 @@ def cap_everything(image_input, visual_chatgpt, text_refiner):
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         text_refiner=text_refiner,
         session_id=iface.app_id
     )
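With these changes, every Gradio callback that constructs a per-request model (upload_callback, inference_click, inference_traject, cap_everything) now passes the module-level shared_ocr_reader into build_caption_anything_with_models alongside shared_captioner and shared_sam_model, so the captioner, SAM checkpoint, and OCR reader are each loaded once per process rather than once per request.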
caption_anything/captioner/blip2.py CHANGED

@@ -6,6 +6,7 @@ from transformers import AutoProcessor, Blip2ForConditionalGeneration
 
 from caption_anything.utils.utils import is_platform_win, load_image
 from .base_captioner import BaseCaptioner
+import time
 
 class BLIP2Captioner(BaseCaptioner):
     def __init__(self, device, dialogue: bool = False, enable_filter: bool = False):
@@ -33,8 +34,7 @@ class BLIP2Captioner(BaseCaptioner):
         if not self.dialogue:
             inputs = self.processor(image, text = args['text_prompt'], return_tensors="pt").to(self.device, self.torch_dtype)
             out = self.model.generate(**inputs, return_dict_in_generate=True, output_scores=True, max_new_tokens=50)
-
-            caption = [caption.strip() for caption in captions][0]
+            caption = self.processor.decode(out.sequences[0], skip_special_tokens=True).strip()
             if self.enable_filter and filter:
                 print('reference caption: {}, caption: {}'.format(args['reference_caption'], caption))
                 clip_score = self.filter_caption(image, caption, args['reference_caption'])
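The blip2.py change only touches the decoding step: with return_dict_in_generate=True, generate() returns an output object whose .sequences field holds the generated token ids, and the caption is recovered by decoding the first sequence directly. A minimal sketch of that decoding path; the checkpoint name, prompt, and image path below are placeholders, not taken from the Space's configuration:

import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration

# Placeholder checkpoint; the Space may use a different BLIP-2 variant.
name = "Salesforce/blip2-opt-2.7b"
processor = AutoProcessor.from_pretrained(name)
model = Blip2ForConditionalGeneration.from_pretrained(name, torch_dtype=torch.float16).to("cuda")

image = Image.open("example.jpg")  # placeholder image
inputs = processor(image, text="a photo of", return_tensors="pt").to("cuda", torch.float16)

# return_dict_in_generate=True -> output object with .sequences (token ids) and .scores
out = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, max_new_tokens=50)
caption = processor.decode(out.sequences[0], skip_special_tokens=True).strip()
print(caption)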
caption_anything/model.py CHANGED

@@ -8,6 +8,7 @@ import numpy as np
 from PIL import Image
 import easyocr
 import copy
+import time
 from caption_anything.captioner import build_captioner, BaseCaptioner
 from caption_anything.segmenter import build_segmenter, build_segmenter_densecap
 from caption_anything.text_refiner import build_text_refiner
@@ -16,14 +17,15 @@ from caption_anything.utils.utils import mask_painter_foreground_all, mask_paint
 from caption_anything.utils.densecap_painter import draw_bbox
 
 class CaptionAnything:
-    def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
+    def __init__(self, args, api_key="", captioner=None, segmenter=None, ocr_reader=None, text_refiner=None):
         self.args = args
         self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
         self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
         self.segmenter_densecap = build_segmenter_densecap(args.segmenter, args.device, args, model=self.segmenter.model)
+        self.ocr_lang = ["ch_tra", "en"]
+        self.ocr_reader = ocr_reader if ocr_reader is not None else easyocr.Reader(self.ocr_lang)
 
-
-        self.reader = easyocr.Reader(self.lang)
+
         self.text_refiner = None
         if not args.disable_gpt:
             if text_refiner is not None:
@@ -31,6 +33,7 @@ class CaptionAnything:
             elif api_key != "":
                 self.init_refiner(api_key)
         self.require_caption_prompt = args.captioner == 'blip2'
+        print('text_refiner init time: ', time.time() - t0)
 
     @property
     def image_embedding(self):
@@ -213,7 +216,7 @@ class CaptionAnything:
     def parse_ocr(self, image, thres=0.2):
         width, height = get_image_shape(image)
         image = load_image(image, return_type='numpy')
-        bounds = self.reader.readtext(image)
+        bounds = self.ocr_reader.readtext(image)
         bounds = [bound for bound in bounds if bound[2] > thres]
         print('Process OCR Text:\n', bounds)
 
@@ -257,7 +260,7 @@ class CaptionAnything:
 if __name__ == "__main__":
     from caption_anything.utils.parser import parse_augment
     args = parse_augment()
-    image_path = '
+    image_path = 'result/wt/memes/87226084.jpg'
     image = Image.open(image_path)
     prompts = [
         {
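For reference, easyocr's readtext returns one (bounding_box, text, confidence) triple per detected text region, which is why parse_ocr filters on bound[2]. A small standalone sketch of that filtering step; the image path is a placeholder:

import easyocr

reader = easyocr.Reader(["ch_tra", "en"])  # same language pair as the shared reader above
# readtext accepts a file path, URL, or numpy array and returns
# a list of (bbox_points, text, confidence) triples.
bounds = reader.readtext("some_image.jpg")  # placeholder path
thres = 0.2
kept = [b for b in bounds if b[2] > thres]  # drop low-confidence detections
for bbox, text, conf in kept:
    print(f"{conf:.2f}  {text}")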
|