Spaces: Running on Zero

Update app.py
Browse files

app.py CHANGED

@@ -1,224 +1,114 @@
-# Copyright (c) AtlasIA.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-import os
-import numpy as np
-from urllib3.exceptions import HTTPError
-os.system('pip install dashscope modelscope oss2 -U')
-
-from argparse import ArgumentParser
-from pathlib import Path
-
-import copy
 import gradio as gr
-import
 import os
-import re
-import secrets
-import tempfile
-import requests
-from http import HTTPStatus
-from dashscope import MultiModalConversation
-import dashscope
-
-API_KEY = os.environ['API_KEY']
-dashscope.api_key = API_KEY
-
-BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
-PUNCTUATION = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
-
-
-def _get_args():
-    parser = ArgumentParser()
-    parser.add_argument("--revision", type=str, default=REVISION)
-    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
-
-    parser.add_argument("--share", action="store_true", default=False,
-                        help="Create a publicly shareable link for the interface.")
-    parser.add_argument("--inbrowser", action="store_true", default=False,
-                        help="Automatically launch the interface in a new tab on the default browser.")
-    parser.add_argument("--server-port", type=int, default=7860,
-                        help="Demo server port.")
-    parser.add_argument("--server-name", type=str, default="127.0.0.1",
-                        help="Demo server name.")
-
-    args = parser.parse_args()
-    return args
-
-def _parse_text(text):
-    lines = text.split("\n")
-    lines = [line for line in lines if line != ""]
-    count = 0
-    for i, line in enumerate(lines):
-        if "```" in line:
-            count += 1
-            items = line.split("`")
-            if count % 2 == 1:
-                lines[i] = f'<pre><code class="language-{items[-1]}">'
-            else:
-                lines[i] = f"<br></code></pre>"
-        else:
-            if i > 0:
-                if count % 2 == 1:
-                    line = line.replace("`", r"\`")
-                    line = line.replace("<", "&lt;")
-                    line = line.replace(">", "&gt;")
-                    line = line.replace(" ", "&nbsp;")
-                    line = line.replace("*", "&ast;")
-                    line = line.replace("_", "&lowbar;")
-                    line = line.replace("-", "&#45;")
-                    line = line.replace(".", "&#46;")
-                    line = line.replace("!", "&#33;")
-                    line = line.replace("(", "&#40;")
-                    line = line.replace(")", "&#41;")
-                    line = line.replace("$", "&#36;")
-                lines[i] = "<br>" + line
-    text = "".join(lines)
-    return text
-
-
-def _remove_image_special(text):
-    text = text.replace('<ref>', '').replace('</ref>', '')
-    return re.sub(r'<box>.*?(</box>|$)', '', text)
-
-

-
-
-
 )
-
-
-
-
-
-
-
-        return _chatbot
-    print("User: " + _parse_text(query))
-    history_cp = copy.deepcopy(task_history)
-    full_response = ""
-    messages = []
-    content = []
-    for q, a in history_cp:
-        if isinstance(q, (tuple, list)):
-            content.append({'image': f'file://{q[0]}'})
-        else:
-            content.append({'text': q})
-            messages.append({'role': 'user', 'content': content})
-            messages.append({'role': 'assistant', 'content': [{'text': a}]})
-            content = []
-    messages.pop()
-    responses = MultiModalConversation.call(
-        model='AtlasOCR', messages=messages, stream=True,
-    )
-    for response in responses:
-        if not response.status_code == HTTPStatus.OK:
-            raise HTTPError(f'response.code: {response.code}\nresponse.message: {response.message}')
-        response = response.output.choices[0].message.content
-        response_text = []
-        for ele in response:
-            if 'text' in ele:
-                response_text.append(ele['text'])
-            elif 'box' in ele:
-                response_text.append(ele['box'])
-        response_text = ''.join(response_text)
-        _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(response_text))
-        yield _chatbot
-
-    if len(response) > 1:
-        result_image = response[-1]['result_image']
-        resp = requests.get(result_image)
-        os.makedirs(uploaded_file_dir, exist_ok=True)
-        name = f"tmp{secrets.token_hex(20)}.jpg"
-        filename = os.path.join(uploaded_file_dir, name)
-        with open(filename, 'wb') as f:
-            f.write(resp.content)
-        response = ''.join(r['box'] if 'box' in r else r['text'] for r in response[:-1])
-        _chatbot.append((None, (filename,)))
-    else:
-        response = response[0]['text']
-        _chatbot[-1] = (_parse_text(chat_query), response)
-    full_response = _parse_text(response)
-
-    task_history[-1] = (query, full_response)
-    print("AtlasOCR-Chat: " + _parse_text(full_response))
-    yield _chatbot
-
-
-def regenerate(_chatbot, task_history):
-    if not task_history:
-        return _chatbot
-    item = task_history[-1]
-    if item[1] is None:
-        return _chatbot
-    task_history[-1] = (item[0], None)
-    chatbot_item = _chatbot.pop(-1)
-    if chatbot_item[0] is None:
-        _chatbot[-1] = (_chatbot[-1][0], None)
-    else:
-        _chatbot.append((chatbot_item[0], None))
-    _chatbot_gen = predict(_chatbot, task_history)
-    for _chatbot in _chatbot_gen:
-        yield _chatbot
-
-def add_text(history, task_history, text):
-    task_text = text
-    history = history if history is not None else []
-    task_history = task_history if task_history is not None else []
-    history = history + [(_parse_text(text), None)]
-    task_history = task_history + [(task_text, None)]
-    return history, task_history, ""
-
-def add_file(history, task_history, file):
-    history = history if history is not None else []
-    task_history = task_history if task_history is not None else []
-    history = history + [((file.name,), None)]
-    task_history = task_history + [((file.name,), None)]
-    return history, task_history
-
-def reset_user_input():
-    return gr.update(value="")
-
-def reset_state(task_history):
-    task_history.clear()
-    return []
-
-with gr.Blocks() as demo:
-    gr.Markdown("""<center><font size=3> AtlasOCR Demo </center>""")
-
-    chatbot = gr.Chatbot(label='AtlasOCR', elem_classes="control-height", height=500)
-    query = gr.Textbox(lines=2, label='Input')
-    task_history = gr.State([])
-
-    with gr.Row():
-        addfile_btn = gr.UploadButton("📁 Upload", file_types=["image"])
-        submit_btn = gr.Button("🚀 Submit")
-        regen_btn = gr.Button("🤔️ Regenerate")
-        empty_bin = gr.Button("🧹 Clear History")
-
-    submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
-        predict, [chatbot, task_history], [chatbot], show_progress=True
-    )
-    submit_btn.click(reset_user_input, [], [query])
-    empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
-    regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
-    addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
-
-
-demo.queue(default_concurrency_limit=40).launch(
-    share=args.share,
-    # inbrowser=args.inbrowser,
-    # server_port=args.server_port,
-    # server_name=args.server_name,
 )
-
-
-
-
-
-
-
-
-
 import gradio as gr
+import time
+import spaces
+from PIL import Image
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
+import torch
+import uuid
 import os
+import numpy as np

+# Load model and processor
+# model_name = "NAMAA-Space/Qari-OCR-0.1-VL-2B-Instruct"
+model_name = "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct"
+model = Qwen2VLForConditionalGeneration.from_pretrained(
+    model_name,
+    torch_dtype="auto",
+    device_map="cuda"
+)
+processor = AutoProcessor.from_pretrained(model_name)
+max_tokens = 2000
+
+
+
+@spaces.GPU
+def perform_ocr(image):
+    """Process image and extract text using OCR model."""
+    # np.any() is falsy for a missing or all-zero input, so bail out early.
+    if not np.any(image):
+        return "Error Processing"
+    image = Image.fromarray(image)
+    src = str(uuid.uuid4()) + ".png"
+    prompt = "Below is the image of one page of a document, as well as some raw textual content that was previously extracted for it. Just return the plain text representation of this document as if you were reading it naturally. Do not hallucinate."
+    image.save(src)
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": f"file://{src}"},
+                {"type": "text", "text": prompt},
+            ],
+        }
+    ]
+
+    # Process inputs
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
     )
+    image_inputs, video_inputs = process_vision_info(messages)
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
     )
+    inputs = inputs.to("cuda")
+
+    # Generate text
+    generated_ids = model.generate(**inputs, max_new_tokens=max_tokens, use_cache=True)
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )[0]
+
+    # Cleanup
+    os.remove(src)
+    return output_text
+
+# Create Gradio interface
+with gr.Blocks(title="Qari Arabic OCR") as demo:
+    gr.Markdown("# Qari Arabic OCR")
+    gr.Markdown("Upload an image to extract Arabic text in real time. This model is specialized for Arabic document OCR.")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            # Input image
+            image_input = gr.Image(type="numpy", label="Upload Image")
+
+            # Example gallery
+            gr.Examples(
+                examples=[
+                    ["2.jpg"],
+                    ["3.jpg"]
+                ],
+                inputs=image_input,
+                label="Example Images",
+                examples_per_page=4
+            )
+
+            # Submit button
+            submit_btn = gr.Button("Extract Text")
+
+        with gr.Column(scale=1):
+            # Output text
+            output = gr.Textbox(label="Extracted Text", lines=20, show_copy_button=True)
+
+    # Model details
+    with gr.Accordion("Model Information", open=False):
+        gr.Markdown("""
+        **Model:** Qari-OCR-0.2.2.1-VL-2B-Instruct
+        **Description:** Arabic OCR model based on the Qwen2-VL architecture
+        **Size:** 2B parameters
+        **Max output:** up to 2000 generated tokens
+        """)
+
+    # Set up processing flow
+    submit_btn.click(fn=perform_ocr, inputs=image_input, outputs=output)
+    image_input.change(fn=perform_ocr, inputs=image_input, outputs=output)
+
+demo.launch()
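
A minimal client-side sketch: once this updated app is running, the Space can be queried remotely with gradio_client. The Space id below is hypothetical, and the api_name assumes Gradio's default naming for the perform_ocr handler; both should be checked against the live Space's "Use via API" page.

from gradio_client import Client, handle_file

client = Client("NAMAA-Space/Qari-OCR-Demo")  # hypothetical Space id
result = client.predict(
    handle_file("page.png"),   # local document image to OCR
    api_name="/perform_ocr",   # assumed default endpoint name
)
print(result)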