abdeljalilELmajjodi committed on
Commit 186fd60 · verified · 1 Parent(s): edf3f62

Update app.py

Files changed (1):
  1. app.py +109 -219

app.py CHANGED
@@ -1,224 +1,114 @@
- # Copyright (c) AtlasIA.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- import os
- import numpy as np
- from urllib3.exceptions import HTTPError
- os.system('pip install dashscope modelscope oss2 -U')
-
- from argparse import ArgumentParser
- from pathlib import Path
-
- import copy
  import gradio as gr
- import oss2
  import os
- import re
- import secrets
- import tempfile
- import requests
- from http import HTTPStatus
- from dashscope import MultiModalConversation
- import dashscope
-
- API_KEY = os.environ['API_KEY']
- dashscope.api_key = API_KEY
-
- BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
- PUNCTUATION = "！？。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
-
-
- def _get_args():
-     parser = ArgumentParser()
-     parser.add_argument("--revision", type=str, default=REVISION)
-     parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
-
-     parser.add_argument("--share", action="store_true", default=False,
-                         help="Create a publicly shareable link for the interface.")
-     parser.add_argument("--inbrowser", action="store_true", default=False,
-                         help="Automatically launch the interface in a new tab on the default browser.")
-     parser.add_argument("--server-port", type=int, default=7860,
-                         help="Demo server port.")
-     parser.add_argument("--server-name", type=str, default="127.0.0.1",
-                         help="Demo server name.")
-
-     args = parser.parse_args()
-     return args
-
- def _parse_text(text):
-     lines = text.split("\n")
-     lines = [line for line in lines if line != ""]
-     count = 0
-     for i, line in enumerate(lines):
-         if "```" in line:
-             count += 1
-             items = line.split("`")
-             if count % 2 == 1:
-                 lines[i] = f'<pre><code class="language-{items[-1]}">'
-             else:
-                 lines[i] = f"<br></code></pre>"
-         else:
-             if i > 0:
-                 if count % 2 == 1:
-                     line = line.replace("`", r"\`")
-                     line = line.replace("<", "&lt;")
-                     line = line.replace(">", "&gt;")
-                     line = line.replace(" ", "&nbsp;")
-                     line = line.replace("*", "&ast;")
-                     line = line.replace("_", "&lowbar;")
-                     line = line.replace("-", "&#45;")
-                     line = line.replace(".", "&#46;")
-                     line = line.replace("!", "&#33;")
-                     line = line.replace("(", "&#40;")
-                     line = line.replace(")", "&#41;")
-                     line = line.replace("$", "&#36;")
-                 lines[i] = "<br>" + line
-     text = "".join(lines)
-     return text
-
-
- def _remove_image_special(text):
-     text = text.replace('<ref>', '').replace('</ref>', '')
-     return re.sub(r'<box>.*?(</box>|$)', '', text)
-
-

- def _launch_demo(args):
-     uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
-         Path(tempfile.gettempdir()) / "gradio"
      )
-
-     def predict(_chatbot, task_history):
-         chat_query = _chatbot[-1][0]
-         query = task_history[-1][0]
-         if len(chat_query) == 0:
-             _chatbot.pop()
-             task_history.pop()
-             return _chatbot
-         print("User: " + _parse_text(query))
-         history_cp = copy.deepcopy(task_history)
-         full_response = ""
-         messages = []
-         content = []
-         for q, a in history_cp:
-             if isinstance(q, (tuple, list)):
-                 content.append({'image': f'file://{q[0]}'})
-             else:
-                 content.append({'text': q})
-                 messages.append({'role': 'user', 'content': content})
-                 messages.append({'role': 'assistant', 'content': [{'text': a}]})
-                 content = []
-         messages.pop()
-         responses = MultiModalConversation.call(
-             model='AtlasOCR', messages=messages, stream=True,
-         )
-         for response in responses:
-             if not response.status_code == HTTPStatus.OK:
-                 raise HTTPError(f'response.code: {response.code}\nresponse.message: {response.message}')
-             response = response.output.choices[0].message.content
-             response_text = []
-             for ele in response:
-                 if 'text' in ele:
-                     response_text.append(ele['text'])
-                 elif 'box' in ele:
-                     response_text.append(ele['box'])
-             response_text = ''.join(response_text)
-             _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(response_text))
-             yield _chatbot
-
-         if len(response) > 1:
-             result_image = response[-1]['result_image']
-             resp = requests.get(result_image)
-             os.makedirs(uploaded_file_dir, exist_ok=True)
-             name = f"tmp{secrets.token_hex(20)}.jpg"
-             filename = os.path.join(uploaded_file_dir, name)
-             with open(filename, 'wb') as f:
-                 f.write(resp.content)
-             response = ''.join(r['box'] if 'box' in r else r['text'] for r in response[:-1])
-             _chatbot.append((None, (filename,)))
-         else:
-             response = response[0]['text']
-             _chatbot[-1] = (_parse_text(chat_query), response)
-         full_response = _parse_text(response)
-
-         task_history[-1] = (query, full_response)
-         print("AtlasOCR-Chat: " + _parse_text(full_response))
-         yield _chatbot
-
-
-     def regenerate(_chatbot, task_history):
-         if not task_history:
-             return _chatbot
-         item = task_history[-1]
-         if item[1] is None:
-             return _chatbot
-         task_history[-1] = (item[0], None)
-         chatbot_item = _chatbot.pop(-1)
-         if chatbot_item[0] is None:
-             _chatbot[-1] = (_chatbot[-1][0], None)
-         else:
-             _chatbot.append((chatbot_item[0], None))
-         _chatbot_gen = predict(_chatbot, task_history)
-         for _chatbot in _chatbot_gen:
-             yield _chatbot
-
-     def add_text(history, task_history, text):
-         task_text = text
-         history = history if history is not None else []
-         task_history = task_history if task_history is not None else []
-         history = history + [(_parse_text(text), None)]
-         task_history = task_history + [(task_text, None)]
-         return history, task_history, ""
-
-     def add_file(history, task_history, file):
-         history = history if history is not None else []
-         task_history = task_history if task_history is not None else []
-         history = history + [((file.name,), None)]
-         task_history = task_history + [((file.name,), None)]
-         return history, task_history
-
-     def reset_user_input():
-         return gr.update(value="")
-
-     def reset_state(task_history):
-         task_history.clear()
-         return []
-
-     with gr.Blocks() as demo:
-         gr.Markdown("""<center><font size=3> AtlasOCR Demo </center>""")
-
-         chatbot = gr.Chatbot(label='AtlasOCR', elem_classes="control-height", height=500)
-         query = gr.Textbox(lines=2, label='Input')
-         task_history = gr.State([])
-
-         with gr.Row():
-             addfile_btn = gr.UploadButton("📁 Upload", file_types=["image"])
-             submit_btn = gr.Button("🚀 Submit")
-             regen_btn = gr.Button("🤔️ Regenerate")
-             empty_bin = gr.Button("🧹 Clear History")
-
-         submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
-             predict, [chatbot, task_history], [chatbot], show_progress=True
-         )
-         submit_btn.click(reset_user_input, [], [query])
-         empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
-         regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
-         addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
-
-
-     demo.queue(default_concurrency_limit=40).launch(
-         share=args.share,
-         # inbrowser=args.inbrowser,
-         # server_port=args.server_port,
-         # server_name=args.server_name,
      )
-
-
- def main():
-     args = _get_args()
-     _launch_demo(args)
-
-
- if __name__ == '__main__':
-     main()
  import gradio as gr
+ import time
+ import spaces
+ from PIL import Image
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+ import torch
+ import uuid
  import os
+ import numpy as np

+ # Load model and processor
+ # model_name = "NAMAA-Space/Qari-OCR-0.1-VL-2B-Instruct"
+ model_name = "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct"
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+     model_name,
+     torch_dtype="auto",
+     device_map="cuda"
+ )
+ processor = AutoProcessor.from_pretrained(model_name)
+ max_tokens = 2000
+
+
+
+ @spaces.GPU
+ def perform_ocr(image):
+     """Process image and extract text using OCR model"""
+     # Guard against an empty canvas or missing upload before doing any work
+     if image is None or not np.any(image):
+         return "Error Processing"
+     image = Image.fromarray(image)
+     src = str(uuid.uuid4()) + ".png"
+     prompt = "Below is the image of one page of a document, as well as some raw textual content that was previously extracted for it. Just return the plain text representation of this document as if you were reading it naturally. Do not hallucinate."
+     image.save(src)
+
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": f"file://{src}"},
+                 {"type": "text", "text": prompt},
+             ],
+         }
+     ]
+
+     # Process inputs
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
      )
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
      )
+     inputs = inputs.to("cuda")
+
+     # Generate text, then strip the prompt tokens from the decoded output
+     generated_ids = model.generate(**inputs, max_new_tokens=max_tokens, use_cache=True)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )[0]
+
+     # Cleanup the temporary image file
+     os.remove(src)
+     return output_text
+
+ # Create Gradio interface
+ with gr.Blocks(title="Qari Arabic OCR") as demo:
+     gr.Markdown("# Qari Arabic OCR")
+     gr.Markdown("Upload an image to extract Arabic text in real time. This model is specialized for Arabic document OCR.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Input image
+             image_input = gr.Image(type="numpy", label="Upload Image")
+
+             # Example gallery
+             gr.Examples(
+                 examples=[
+                     ["2.jpg"],
+                     ["3.jpg"]
+                 ],
+                 inputs=image_input,
+                 label="Example Images",
+                 examples_per_page=4
+             )
+
+             # Submit button
+             submit_btn = gr.Button("Extract Text")
+
+         with gr.Column(scale=1):
+             # Output text
+             output = gr.Textbox(label="Extracted Text", lines=20, show_copy_button=True)
+
+     # Model details
+     with gr.Accordion("Model Information", open=False):
+         gr.Markdown("""
+         **Model:** Qari-OCR-0.2.2.1-VL-2B-Instruct
+         **Description:** Arabic OCR model based on the Qwen2-VL architecture
+         **Size:** 2B parameters
+         **Output length:** up to 2000 generated tokens
+         """)
+
+     # Set up processing flow: run OCR on submit and whenever the image changes
+     submit_btn.click(fn=perform_ocr, inputs=image_input, outputs=output)
+     image_input.change(fn=perform_ocr, inputs=image_input, outputs=output)
+
+ demo.launch()
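
A hedged sketch of a device-agnostic variant of the loading block above, in case the demo needs to run without CUDA; device_map="auto" is a standard transformers option, but using it here is an assumption on my part, not something this commit does:

# Sketch: same load as in the commit, letting transformers pick the device.
# The .to("cuda") call inside perform_ocr would then need to become
# .to(model.device) for CPU-only machines (also an assumption).
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct",
    torch_dtype="auto",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct")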
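
A minimal sketch of exercising perform_ocr directly, without the UI, assuming the module above has been loaded and a GPU is available (the function moves tensors to "cuda"). The file name "page.png" is hypothetical; the array conversion mirrors what gr.Image(type="numpy") hands the function:

# Sketch: call the OCR function as plain Python (hypothetical test image).
import numpy as np
from PIL import Image

arr = np.array(Image.open("page.png").convert("RGB"))  # uint8 HxWx3 array
print(perform_ocr(arr))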
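
A sketch of querying a running copy of this demo over its API with gradio_client. The URL assumes a local launch on the default port, and the endpoint name "/perform_ocr" follows Gradio's default naming for event handlers; both are assumptions, so check the app's "Use via API" page for the real values:

# Sketch: remote call to the launched demo (URL and api_name are assumptions).
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860/")
text = client.predict(handle_file("page.png"), api_name="/perform_ocr")
print(text)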