nielsr HF Staff Claude commited on
Commit
fe64308
Β·
1 Parent(s): 8299bc6

Add KOSMOS-2.5 Document AI Demo

Browse files

- Three interactive modes: Markdown generation, OCR with bounding boxes, and Document Q&A
- Support for both microsoft/kosmos-2.5 and microsoft/kosmos-2.5-chat models
- ZeroGPU integration with @spaces.GPU decorators
- Visual OCR with bounding box overlays
- Professional Gradio interface with tabbed layout

πŸ€– Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

Files changed (3) hide show
  1. README.md +55 -7
  2. app.py +266 -0
  3. requirements.txt +6 -0
README.md CHANGED
@@ -1,13 +1,61 @@
1
  ---
2
- title: Kosmos 2.5 Demo
3
- emoji: 🐠
4
- colorFrom: purple
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.44.0
8
  app_file: app.py
9
  pinned: false
10
- short_description: A demo showcasing the abilities of KOSMOS-2.5
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: KOSMOS-2.5 Document AI Demo
3
+ emoji: πŸ“„
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ # KOSMOS-2.5 Document AI Demo
14
+
15
+ This Space demonstrates the capabilities of Microsoft's **KOSMOS-2.5**, a multimodal literate model for machine reading of text-intensive images.
16
+
17
+ ## Features
18
+
19
+ πŸ”₯ **Three powerful modes**:
20
+
21
+ 1. **πŸ“ Markdown Generation**: Convert document images to clean markdown format
22
+ 2. **πŸ” OCR with Bounding Boxes**: Extract text with precise spatial coordinates and visualization
23
+ 3. **πŸ’¬ Document Q&A**: Ask questions about document content using KOSMOS-2.5 Chat
24
+
25
+ ## What is KOSMOS-2.5?
26
+
27
+ KOSMOS-2.5 is Microsoft's latest document AI model that excels at understanding text-rich images. It can:
28
+
29
+ - Generate spatially-aware text blocks with coordinates
30
+ - Produce structured markdown output that captures document styles
31
+ - Answer questions about document content through the chat variant
32
+
33
+ The model was pre-trained on 357.4 million text-rich document images and achieves performance comparable to much larger models (1.3B vs 7B parameters) on visual question answering benchmarks.
34
+
35
+ ## Example Use Cases
36
+
37
+ - **Receipts**: Extract itemized information or ask "What's the total amount?"
38
+ - **Forms**: Convert to structured format or query specific fields
39
+ - **Articles**: Get clean markdown or ask content-specific questions
40
+ - **Screenshots**: Extract UI text or get information about elements
41
+
42
+ ## Model Information
43
+
44
+ - **Base Model**: [microsoft/kosmos-2.5](https://huggingface.co/microsoft/kosmos-2.5)
45
+ - **Chat Model**: [microsoft/kosmos-2.5-chat](https://huggingface.co/microsoft/kosmos-2.5-chat)
46
+ - **Paper**: [Kosmos-2.5: A Multimodal Literate Model](https://arxiv.org/abs/2309.11419)
47
+
48
+ ## Note
49
+
50
+ This is a generative model and may occasionally produce inaccurate results. Please verify outputs for critical applications.
51
+
52
+ ## Citation
53
+
54
+ ```bibtex
55
+ @article{lv2023kosmos,
56
+ title={Kosmos-2.5: A multimodal literate model},
57
+ author={Lv, Tengchao and Huang, Yupan and Chen, Jingye and Cui, Lei and Ma, Shuming and Chang, Yaoyao and Huang, Shaohan and Wang, Wenhui and Dong, Li and Luo, Weiyao and others},
58
+ journal={arXiv preprint arXiv:2309.11419},
59
+ year={2023}
60
+ }
61
+ ```
app.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import torch
3
+ import gradio as gr
4
+ from PIL import Image
5
+ from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
6
+ import re
7
+
8
+ # Check if CUDA is available
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
11
+
12
+
13
+ # Initialize models and processors
14
+ @spaces.GPU
15
+ def load_models():
16
+ base_repo = "microsoft/kosmos-2.5"
17
+ chat_repo = "microsoft/kosmos-2.5-chat"
18
+
19
+ base_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
20
+ base_repo,
21
+ device_map=device,
22
+ torch_dtype=dtype,
23
+ attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
24
+ )
25
+ base_processor = AutoProcessor.from_pretrained(base_repo)
26
+
27
+ chat_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
28
+ chat_repo,
29
+ device_map=device,
30
+ torch_dtype=dtype,
31
+ attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
32
+ )
33
+ chat_processor = AutoProcessor.from_pretrained(chat_repo)
34
+
35
+ return base_model, base_processor, chat_model, chat_processor
36
+
37
+ base_model, base_processor, chat_model, chat_processor = load_models()
38
+
39
+ def post_process_ocr(y, scale_height, scale_width, prompt="<ocr>"):
40
+ y = y.replace(prompt, "")
41
+ if "<md>" in prompt:
42
+ return y
43
+
44
+ pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
45
+ bboxs_raw = re.findall(pattern, y)
46
+ lines = re.split(pattern, y)[1:]
47
+ bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
48
+ bboxs = [[int(j) for j in i] for i in bboxs]
49
+
50
+ info = ""
51
+ for i in range(len(lines)):
52
+ if i < len(bboxs):
53
+ box = bboxs[i]
54
+ x0, y0, x1, y1 = box
55
+ if not (x0 >= x1 or y0 >= y1):
56
+ x0 = int(x0 * scale_width)
57
+ y0 = int(y0 * scale_height)
58
+ x1 = int(x1 * scale_width)
59
+ y1 = int(y1 * scale_height)
60
+ info += f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{lines[i]}\n"
61
+ return info.strip()
62
+
63
+ @spaces.GPU
64
+ def generate_markdown(image):
65
+ if image is None:
66
+ return "Please upload an image."
67
+
68
+ prompt = "<md>"
69
+ inputs = base_processor(text=prompt, images=image, return_tensors="pt")
70
+
71
+ height, width = inputs.pop("height"), inputs.pop("width")
72
+ raw_width, raw_height = image.size
73
+ scale_height = raw_height / height
74
+ scale_width = raw_width / width
75
+
76
+ inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
77
+ inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
78
+
79
+ with torch.no_grad():
80
+ generated_ids = base_model.generate(
81
+ **inputs,
82
+ max_new_tokens=1024,
83
+ )
84
+
85
+ generated_text = base_processor.batch_decode(generated_ids, skip_special_tokens=True)
86
+ result = generated_text[0].replace(prompt, "").strip()
87
+
88
+ return result
89
+
90
+ @spaces.GPU
91
+ def generate_ocr(image):
92
+ if image is None:
93
+ return "Please upload an image.", None
94
+
95
+ prompt = "<ocr>"
96
+ inputs = base_processor(text=prompt, images=image, return_tensors="pt")
97
+
98
+ height, width = inputs.pop("height"), inputs.pop("width")
99
+ raw_width, raw_height = image.size
100
+ scale_height = raw_height / height
101
+ scale_width = raw_width / width
102
+
103
+ inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
104
+ inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
105
+
106
+ with torch.no_grad():
107
+ generated_ids = base_model.generate(
108
+ **inputs,
109
+ max_new_tokens=1024,
110
+ )
111
+
112
+ generated_text = base_processor.batch_decode(generated_ids, skip_special_tokens=True)
113
+
114
+ # Post-process OCR output
115
+ output_text = post_process_ocr(generated_text[0], scale_height, scale_width)
116
+
117
+ # Create visualization
118
+ from PIL import ImageDraw
119
+ vis_image = image.copy()
120
+ draw = ImageDraw.Draw(vis_image)
121
+
122
+ lines = output_text.split("\n")
123
+ for line in lines:
124
+ if not line.strip():
125
+ continue
126
+ parts = line.split(",")
127
+ if len(parts) >= 8:
128
+ try:
129
+ coords = list(map(int, parts[:8]))
130
+ draw.polygon(coords, outline="red", width=2)
131
+ except:
132
+ continue
133
+
134
+ return output_text, vis_image
135
+
136
+ @spaces.GPU
137
+ def generate_chat_response(image, question):
138
+ if image is None:
139
+ return "Please upload an image."
140
+ if not question.strip():
141
+ return "Please ask a question."
142
+
143
+ template = "<md>A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
144
+ prompt = template.format(question)
145
+
146
+ inputs = chat_processor(text=prompt, images=image, return_tensors="pt")
147
+
148
+ height, width = inputs.pop("height"), inputs.pop("width")
149
+ raw_width, raw_height = image.size
150
+ scale_height = raw_height / height
151
+ scale_width = raw_width / width
152
+
153
+ inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
154
+ inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
155
+
156
+ with torch.no_grad():
157
+ generated_ids = chat_model.generate(
158
+ **inputs,
159
+ max_new_tokens=1024,
160
+ )
161
+
162
+ generated_text = chat_processor.batch_decode(generated_ids, skip_special_tokens=True)
163
+
164
+ # Extract only the assistant's response
165
+ result = generated_text[0]
166
+ if "ASSISTANT:" in result:
167
+ result = result.split("ASSISTANT:")[-1].strip()
168
+
169
+ return result
170
+
171
+ # Create Gradio interface
172
+ with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as demo:
173
+ gr.Markdown("""
174
+ # KOSMOS-2.5 Document AI Demo
175
+
176
+ Explore Microsoft's KOSMOS-2.5, a multimodal model for reading text-intensive images!
177
+ This demo showcases three capabilities:
178
+
179
+ 1. **Markdown Generation**: Convert document images to markdown format
180
+ 2. **OCR with Bounding Boxes**: Extract text with spatial coordinates
181
+ 3. **Document Q&A**: Ask questions about document content using KOSMOS-2.5 Chat
182
+
183
+ Upload a document image (receipt, form, article, etc.) and try different tasks!
184
+ """)
185
+
186
+ with gr.Tabs():
187
+ # Markdown Generation Tab
188
+ with gr.TabItem("πŸ“ Markdown Generation"):
189
+ with gr.Row():
190
+ with gr.Column():
191
+ md_image = gr.Image(type="pil", label="Upload Document Image")
192
+ md_button = gr.Button("Generate Markdown", variant="primary")
193
+ with gr.Column():
194
+ md_output = gr.Textbox(
195
+ label="Generated Markdown",
196
+ lines=15,
197
+ max_lines=20,
198
+ show_copy_button=True
199
+ )
200
+
201
+ # OCR Tab
202
+ with gr.TabItem("πŸ” OCR with Bounding Boxes"):
203
+ with gr.Row():
204
+ with gr.Column():
205
+ ocr_image = gr.Image(type="pil", label="Upload Document Image")
206
+ ocr_button = gr.Button("Extract Text with Coordinates", variant="primary")
207
+ with gr.Column():
208
+ with gr.Row():
209
+ ocr_text = gr.Textbox(
210
+ label="Extracted Text with Coordinates",
211
+ lines=10,
212
+ show_copy_button=True
213
+ )
214
+ ocr_vis = gr.Image(label="Visualization (Red boxes show detected text)")
215
+
216
+ # Chat Tab
217
+ with gr.TabItem("πŸ’¬ Document Q&A (Chat)"):
218
+ with gr.Row():
219
+ with gr.Column():
220
+ chat_image = gr.Image(type="pil", label="Upload Document Image")
221
+ chat_question = gr.Textbox(
222
+ label="Ask a question about the document",
223
+ placeholder="e.g., What is the total amount on this receipt?",
224
+ lines=2
225
+ )
226
+ chat_button = gr.Button("Get Answer", variant="primary")
227
+ with gr.Column():
228
+ chat_output = gr.Textbox(
229
+ label="Answer",
230
+ lines=8,
231
+ show_copy_button=True
232
+ )
233
+
234
+ # Event handlers
235
+ md_button.click(
236
+ fn=generate_markdown,
237
+ inputs=[md_image],
238
+ outputs=[md_output]
239
+ )
240
+
241
+ ocr_button.click(
242
+ fn=generate_ocr,
243
+ inputs=[ocr_image],
244
+ outputs=[ocr_text, ocr_vis]
245
+ )
246
+
247
+ chat_button.click(
248
+ fn=generate_chat_response,
249
+ inputs=[chat_image, chat_question],
250
+ outputs=[chat_output]
251
+ )
252
+
253
+ # Examples section
254
+ gr.Markdown("""
255
+ ## Example Use Cases:
256
+ - **Receipts**: Extract itemized information or ask about totals
257
+ - **Forms**: Convert to structured format or answer specific questions
258
+ - **Articles**: Get markdown format or ask about content
259
+ - **Screenshots**: Extract text or get information about specific elements
260
+
261
+ ## Note:
262
+ This is a generative model and may occasionally hallucinate. Results should be verified for accuracy.
263
+ """)
264
+
265
+ if __name__ == "__main__":
266
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==4.44.0
2
+ torch>=2.0.0
3
+ transformers>=4.56.0
4
+ pillow
5
+ requests
6
+ spaces