prithivMLmods committed
Commit 7aa17ad · verified · 1 Parent(s): 3bc8f4c

update app

Files changed (1)
  1. app.py +246 -339
app.py CHANGED
@@ -5,9 +5,7 @@ import json
 import time
 import asyncio
 from threading import Thread
-from pathlib import Path
-from io import BytesIO
-from typing import Optional, Tuple, Dict, Any, Iterable
 
 import gradio as gr
 import spaces
@@ -15,50 +13,42 @@ import torch
 import numpy as np
 from PIL import Image
 import cv2
-import requests
-import fitz
 
 from transformers import (
-    Qwen3VLMoeForConditionalGeneration,
     AutoProcessor,
     TextIteratorStreamer,
 )
 from transformers.image_utils import load_image
-
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 
 # --- Theme and CSS Definition ---
 
-# Define the Thistle color palette
-colors.thistle = colors.Color(
-    name="thistle",
-    c50="#F9F5F9",
-    c100="#F0E8F1",
-    c200="#E7DBE8",
-    c300="#DECEE0",
-    c400="#D2BFD8",
-    c500="#D8BFD8", # Thistle base color
-    c600="#B59CB7",
-    c700="#927996",
-    c800="#6F5675",
-    c900="#4C3454",
-    c950="#291233",
-)
-
-colors.red_gray = colors.Color(
-    name="red_gray",
-    c50="#f7eded", c100="#f5dcdc", c200="#efb4b4", c300="#e78f8f",
-    c400="#d96a6a", c500="#c65353", c600="#b24444", c700="#8f3434",
-    c800="#732d2d", c900="#5f2626", c950="#4d2020",
 )
 
-class ThistleTheme(Soft):
     def __init__(
         self,
         *,
         primary_hue: colors.Color | str = colors.gray,
-        secondary_hue: colors.Color | str = colors.thistle, # Use the new color
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
@@ -77,92 +67,24 @@ class ThistleTheme(Soft):
             font_mono=font_mono,
         )
         super().set(
-            background_fill_primary="*primary_50",
-            background_fill_primary_dark="*primary_900",
-            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
-            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
-            button_primary_text_color="black",
-            button_primary_text_color_hover="white",
             button_primary_background_fill="linear-gradient(90deg, *secondary_400, *secondary_500)",
             button_primary_background_fill_hover="linear-gradient(90deg, *secondary_500, *secondary_600)",
-            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
-            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
-            button_secondary_text_color="black",
-            button_secondary_text_color_hover="white",
-            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
-            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
-            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
-            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
-            slider_color="*secondary_400",
-            slider_color_dark="*secondary_600",
             block_title_text_weight="600",
-            block_border_width="3px",
             block_shadow="*shadow_drop_lg",
-            button_primary_shadow="*shadow_drop_lg",
-            button_large_padding="11px",
-            color_accent_soft="*primary_100",
-            block_label_background_fill="*primary_200",
         )
 
 # Instantiate the new theme
-thistle_theme = ThistleTheme()
-
-css = """
-#main-title h1 {
-    font-size: 2.3em !important;
-}
-#output-title h2 {
-    font-size: 2.1em !important;
-}
-:root {
-    --color-grey-50: #f9fafb;
-    --banner-background: var(--secondary-400);
-    --banner-text-color: var(--primary-100);
-    --banner-background-dark: var(--secondary-800);
-    --banner-text-color-dark: var(--primary-100);
-    --banner-chrome-height: calc(16px + 43px);
-    --chat-chrome-height-wide-no-banner: 320px;
-    --chat-chrome-height-narrow-no-banner: 450px;
-    --chat-chrome-height-wide: calc(var(--chat-chrome-height-wide-no-banner) + var(--banner-chrome-height));
-    --chat-chrome-height-narrow: calc(var(--chat-chrome-height-narrow-no-banner) + var(--banner-chrome-height));
-}
-.banner-message { background-color: var(--banner-background); padding: 5px; margin: 0; border-radius: 5px; border: none; }
-.banner-message-text { font-size: 13px; font-weight: bolder; color: var(--banner-text-color) !important; }
-body.dark .banner-message { background-color: var(--banner-background-dark) !important; }
-body.dark .gradio-container .contain .banner-message .banner-message-text { color: var(--banner-text-color-dark) !important; }
-.toast-body { background-color: var(--color-grey-50); }
-.html-container:has(.css-styles) { padding: 0; margin: 0; }
-.css-styles { height: 0; }
-.model-message { text-align: end; }
-.model-dropdown-container { display: flex; align-items: center; gap: 10px; padding: 0; }
-.user-input-container .multimodal-textbox{ border: none !important; }
-.control-button { height: 51px; }
-button.cancel { border: var(--button-border-width) solid var(--button-cancel-border-color); background: var(--button-cancel-background-fill); color: var(--button-cancel-text-color); box-shadow: var(--button-cancel-shadow); }
-button.cancel:hover, .cancel[disabled] { background: var(--button-cancel-background-fill-hover); color: var(--button-cancel-text-color-hover); }
-.opt-out-message { top: 8px; }
-.opt-out-message .html-container, .opt-out-checkbox label { font-size: 14px !important; padding: 0 !important; margin: 0 !important; color: var(--neutral-400) !important; }
-div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-wide)) !important; max-height: 900px !important; }
-div.no-padding { padding: 0 !important; }
-@media (max-width: 1280px) { div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-wide)) !important; } }
-@media (max-width: 1024px) {
-    .responsive-row { flex-direction: column; }
-    .model-message { text-align: start; font-size: 10px !important; }
-    .model-dropdown-container { flex-direction: column; align-items: flex-start; }
-    div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-narrow)) !important; }
-}
-@media (max-width: 400px) {
-    .responsive-row { flex-direction: column; }
-    .model-message { text-align: start; font-size: 10px !important; }
-    .model-dropdown-container { flex-direction: column; align-items: flex-start; }
-    div.block.chatbot { max-height: 360px !important; }
-}
-@media (max-height: 932px) { .chatbot { max-height: 500px !important; } }
-@media (max-height: 1280px) { div.block.chatbot { max-height: 800px !important; } }
-"""
 
-MAX_MAX_NEW_TOKENS = 4096
 DEFAULT_MAX_NEW_TOKENS = 1024
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
 print("torch.__version__ =", torch.__version__)
@@ -172,127 +94,129 @@ print("cuda device count:", torch.cuda.device_count())
 if torch.cuda.is_available():
     print("current device:", torch.cuda.current_device())
     print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
 print("Using device:", device)
 
-MODEL_ID_Q3VL = "Qwen/Qwen3-VL-30B-A3B-Instruct"
-processor_q3vl = AutoProcessor.from_pretrained(MODEL_ID_Q3VL, trust_remote_code=True, use_fast=False)
-model_q3vl = Qwen3VLMoeForConditionalGeneration.from_pretrained(
-    MODEL_ID_Q3VL,
     trust_remote_code=True,
-    dtype=torch.float16
 ).to(device).eval()
 
-def extract_gif_frames(gif_path: str):
-    if not gif_path:
-        return []
-    with Image.open(gif_path) as gif:
-        total_frames = gif.n_frames
-        frame_indices = np.linspace(0, total_frames - 1, min(total_frames, 10), dtype=int)
-        frames = []
-        for i in frame_indices:
-            gif.seek(i)
-            frames.append(gif.convert("RGB").copy())
-    return frames
 
 def downsample_video(video_path):
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
     frames = []
-    frame_indices = np.linspace(0, total_frames - 1, min(total_frames, 10), dtype=int)
     for i in frame_indices:
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
         success, image = vidcap.read()
         if success:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             pil_image = Image.fromarray(image)
-            frames.append(pil_image)
     vidcap.release()
     return frames
 
-def convert_pdf_to_images(file_path: str, dpi: int = 200):
-    if not file_path:
-        return []
-    images = []
-    pdf_document = fitz.open(file_path)
-    zoom = dpi / 72.0
-    mat = fitz.Matrix(zoom, zoom)
-    for page_num in range(len(pdf_document)):
-        page = pdf_document.load_page(page_num)
-        pix = page.get_pixmap(matrix=mat)
-        img_data = pix.tobytes("png")
-        images.append(Image.open(BytesIO(img_data)))
-    pdf_document.close()
-    return images
-
-def get_initial_pdf_state() -> Dict[str, Any]:
-    return {"pages": [], "total_pages": 0, "current_page_index": 0}
-
-def load_and_preview_pdf(file_path: Optional[str]) -> Tuple[Optional[Image.Image], Dict[str, Any], str]:
-    state = get_initial_pdf_state()
-    if not file_path:
-        return None, state, '<div style="text-align:center;">No file loaded</div>'
-    try:
-        pages = convert_pdf_to_images(file_path)
-        if not pages:
-            return None, state, '<div style="text-align:center;">Could not load file</div>'
-        state["pages"] = pages
-        state["total_pages"] = len(pages)
-        page_info_html = f'<div style="text-align:center;">Page 1 / {state["total_pages"]}</div>'
-        return pages[0], state, page_info_html
-    except Exception as e:
-        return None, state, f'<div style="text-align:center;">Failed to load preview: {e}</div>'
-
-def navigate_pdf_page(direction: str, state: Dict[str, Any]):
-    if not state or not state["pages"]:
-        return None, state, '<div style="text-align:center;">No file loaded</div>'
-    current_index = state["current_page_index"]
-    total_pages = state["total_pages"]
-    if direction == "prev":
-        new_index = max(0, current_index - 1)
-    elif direction == "next":
-        new_index = min(total_pages - 1, current_index + 1)
     else:
-        new_index = current_index
-    state["current_page_index"] = new_index
-    image_preview = state["pages"][new_index]
-    page_info_html = f'<div style="text-align:center;">Page {new_index + 1} / {total_pages}</div>'
-    return image_preview, state, page_info_html
 
-@spaces.GPU
-def generate_image(text: str, image: Image.Image, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
     if image is None:
         yield "Please upload an image.", "Please upload an image."
         return
-    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}]
-    prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor_q3vl(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
-    streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-    thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
-    thread.start()
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        time.sleep(0.01)
-        yield buffer, buffer
 
-@spaces.GPU
-def generate_video(text: str, video_path: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
-    if video_path is None:
-        yield "Please upload a video.", "Please upload a video."
-        return
-    frames = downsample_video(video_path)
-    if not frames:
-        yield "Could not process video.", "Could not process video."
-        return
-    messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
-    for frame in frames:
-        messages[0]["content"].insert(0, {"type": "image"})
-    prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor_q3vl(text=[prompt_full], images=frames, return_tensors="pt", padding=True).to(device)
-    streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty}
-    thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
     thread.start()
     buffer = ""
     for new_text in streamer:
@@ -302,73 +226,69 @@ def generate_video(text: str, video_path: str, max_new_tokens: int = 1024, tempe
         yield buffer, buffer
 
 @spaces.GPU
-def generate_pdf(text: str, state: Dict[str, Any], max_new_tokens: int = 2048, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
-    if not state or not state["pages"]:
-        yield "Please upload a PDF file first.", "Please upload a PDF file first."
         return
-    page_images = state["pages"]
-    full_response = ""
-    for i, image in enumerate(page_images):
-        page_header = f"--- Page {i+1}/{len(page_images)} ---\n"
-        yield full_response + page_header, full_response + page_header
-        messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}]
-        prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = processor_q3vl(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
-        streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-        thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
-        thread.start()
-        page_buffer = ""
-        for new_text in streamer:
-            page_buffer += new_text
-            yield full_response + page_header + page_buffer, full_response + page_header + page_buffer
-            time.sleep(0.01)
-        full_response += page_header + page_buffer + "\n\n"
 
-@spaces.GPU
-def generate_caption(image: Image.Image, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
-    if image is None:
-        yield "Please upload an image to caption.", "Please upload an image to caption."
         return
-    system_prompt = (
-        "You are an AI assistant that rigorously follows this response protocol: For every input image, your primary "
-        "task is to write a precise caption that captures the essence of the image in clear, concise, and contextually "
-        "accurate language. Along with the caption, provide a structured set of attributes describing the visual "
-        "elements, including details such as objects, people, actions, colors, environment, mood, and other notable "
-        "characteristics. Ensure captions are precise, neutral, and descriptive, avoiding unnecessary elaboration or "
-        "subjective interpretation unless explicitly required. Do not reference the rules or instructions in the output; "
-        "only return the formatted caption, attributes, and class_name."
-    )
-    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": system_prompt}]}]
-    prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor_q3vl(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
-    streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-    thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
-    thread.start()
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        time.sleep(0.01)
-        yield buffer, buffer
 
-@spaces.GPU
-def generate_gif(text: str, gif_path: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
-    if gif_path is None:
-        yield "Please upload a GIF.", "Please upload a GIF."
-        return
-    frames = extract_gif_frames(gif_path)
-    if not frames:
-        yield "Could not process GIF.", "Could not process GIF."
-        return
     messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
-    for frame in frames:
         messages[0]["content"].insert(0, {"type": "image"})
-    prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor_q3vl(text=[prompt_full], images=frames, return_tensors="pt", padding=True).to(device)
-    streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty}
-    thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
     thread.start()
     buffer = ""
     for new_text in streamer:
@@ -377,98 +297,85 @@ def generate_gif(text: str, gif_path: str, max_new_tokens: int = 1024, temperatu
         time.sleep(0.01)
         yield buffer, buffer
 
-image_examples = [["Perform OCR on the image...", "examples/images/1.jpg"],
-                  ["Caption the image. Describe the safety measures shown in the image. Conclude whether the situation is (safe or unsafe)...", "examples/images/2.jpg"],
-                  ["Solve the problem...", "examples/images/3.png"]]
-video_examples = [["Explain the Ad video in detail.", "examples/videos/1.mp4"],
-                  ["Explain the video in detail.", "examples/videos/2.mp4"]]
-pdf_examples = [["Extract the content precisely.", "examples/pdfs/doc1.pdf"],
-                ["Analyze and provide a short report.", "examples/pdfs/doc2.pdf"]]
-gif_examples = [["Describe this GIF.", "examples/gifs/1.gif"],
-                ["Describe this GIF.", "examples/gifs/2.gif"]]
-#caption_examples = [["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG"],
-#                    ["examples/captions/2.png"], ["examples/captions/3.png"]]
-caption_examples = [["examples/captions/1.JPG"],
-                    ["examples/captions/2.jpeg"], ["examples/captions/3.jpeg"]]
 
-with gr.Blocks(theme=thistle_theme, css=css) as demo:
-    pdf_state = gr.State(value=get_initial_pdf_state())
-    gr.Markdown("# **Qwen-3VL:Multimodal**", elem_id="main-title")
     with gr.Row():
         with gr.Column(scale=2):
             with gr.Tabs():
                 with gr.TabItem("Image Inference"):
                     image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                    image_upload = gr.Image(type="pil", label="Image", height=290)
                     image_submit = gr.Button("Submit", variant="primary")
-                    gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
-
                 with gr.TabItem("Video Inference"):
                     video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                    video_upload = gr.Video(label="Video", height=290)
                     video_submit = gr.Button("Submit", variant="primary")
-                    gr.Examples(examples=video_examples, inputs=[video_query, video_upload])
-
-                with gr.TabItem("PDF Inference"):
-                    with gr.Row():
-                        with gr.Column(scale=1):
-                            pdf_query = gr.Textbox(label="Query Input", placeholder="e.g., 'Summarize this document'")
-                            pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
-                            pdf_submit = gr.Button("Submit", variant="primary")
-                        with gr.Column(scale=1):
-                            pdf_preview_img = gr.Image(label="PDF Preview", height=290)
-                            with gr.Row():
-                                prev_page_btn = gr.Button("◀ Previous")
-                                page_info = gr.HTML('<div style="text-align:center;">No file loaded</div>')
-                                next_page_btn = gr.Button("Next ▶")
-                    gr.Examples(examples=pdf_examples, inputs=[pdf_query, pdf_upload])
-
-                with gr.TabItem("Gif Inference"):
-                    gif_query = gr.Textbox(label="Query Input", placeholder="e.g., 'What is happening in this gif?'")
-                    gif_upload = gr.Image(type="filepath", label="Upload GIF", height=290)
-                    gif_submit = gr.Button("Submit", variant="primary")
-                    gr.Examples(examples=gif_examples, inputs=[gif_query, gif_upload])
-
-                with gr.TabItem("Caption"):
-                    caption_image_upload = gr.Image(type="pil", label="Image to Caption", height=290)
-                    caption_submit = gr.Button("Generate Caption", variant="primary")
-                    gr.Examples(examples=caption_examples, inputs=[caption_image_upload])
-
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
-
        with gr.Column(scale=3):
-            gr.Markdown("## Output", elem_id="output-title")
-            output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=14, show_copy_button=True)
-            with gr.Accordion("(Result.md)", open=False):
-                markdown_output = gr.Markdown(label="(Result.Md)", latex_delimiters=[
-                    {"left": "$$", "right": "$$", "display": True},
-                    {"left": "$", "right": "$", "display": False}
-                ])
-
-    image_submit.click(fn=generate_image,
-                       inputs=[image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                       outputs=[output, markdown_output])
-    video_submit.click(fn=generate_video,
-                       inputs=[video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                       outputs=[output, markdown_output])
-    pdf_submit.click(fn=generate_pdf,
-                     inputs=[pdf_query, pdf_state, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                     outputs=[output, markdown_output])
-    gif_submit.click(fn=generate_gif,
-                     inputs=[gif_query, gif_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                     outputs=[output, markdown_output])
-    caption_submit.click(fn=generate_caption,
-                         inputs=[caption_image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                         outputs=[output, markdown_output])
-
-    pdf_upload.change(fn=load_and_preview_pdf, inputs=[pdf_upload], outputs=[pdf_preview_img, pdf_state, page_info])
-    prev_page_btn.click(fn=lambda s: navigate_pdf_page("prev", s), inputs=[pdf_state], outputs=[pdf_preview_img, pdf_state, page_info])
-    next_page_btn.click(fn=lambda s: navigate_pdf_page("next", s), inputs=[pdf_state], outputs=[pdf_preview_img, pdf_state, page_info])
 
 if __name__ == "__main__":
     demo.queue(max_size=50).launch(mcp_server=True, ssr_mode=False, show_error=True)

 import time
 import asyncio
 from threading import Thread
+from typing import Iterable
 
 import gradio as gr
 import spaces

 import numpy as np
 from PIL import Image
 import cv2
 
 from transformers import (
+    Qwen2VLForConditionalGeneration,
+    Qwen2_5_VLForConditionalGeneration,
+    AutoModelForImageTextToText,
     AutoProcessor,
     TextIteratorStreamer,
 )
 from transformers.image_utils import load_image
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 
 # --- Theme and CSS Definition ---
 
+# Define the new PaleGreen color palette
+colors.pale_green = colors.Color(
+    name="pale_green",
+    c50="#F3FEF3",
+    c100="#E7FDE7",
+    c200="#D5FCD5",
+    c300="#C4FBC4",
+    c400="#B1FBAF",
+    c500="#98FB98", # PaleGreen base color
+    c600="#89E289",
+    c700="#7AC87A",
+    c800="#6BAF6B",
+    c900="#5B965B",
+    c950="#4C7D4C",
 )
 
+class PaleGreenTheme(Soft):
     def __init__(
         self,
         *,
         primary_hue: colors.Color | str = colors.gray,
+        secondary_hue: colors.Color | str = colors.pale_green,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (

             font_mono=font_mono,
         )
         super().set(
             button_primary_background_fill="linear-gradient(90deg, *secondary_400, *secondary_500)",
             button_primary_background_fill_hover="linear-gradient(90deg, *secondary_500, *secondary_600)",
+            button_primary_text_color="black",
+            slider_color="*secondary_500",
             block_title_text_weight="600",
+            block_border_width="2px",
             block_shadow="*shadow_drop_lg",
         )
 
 # Instantiate the new theme
+pale_green_theme = PaleGreenTheme()
 
+# Constants for text generation
+MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
 print("torch.__version__ =", torch.__version__)

 if torch.cuda.is_available():
     print("current device:", torch.cuda.current_device())
     print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
+
 print("Using device:", device)
 
+# --- Model Loading ---
+# Load Nanonets-OCR-s
+MODEL_ID_V = "nanonets/Nanonets-OCR-s"
+processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
+model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_V,
     trust_remote_code=True,
+    torch_dtype=torch.float16
 ).to(device).eval()
 
+# Load Qwen2-VL-OCR-2B-Instruct
+MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
+model_x = Qwen2VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_X,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
+
+# Load Aya-Vision-8b
+MODEL_ID_A = "CohereForAI/aya-vision-8b"
+processor_a = AutoProcessor.from_pretrained(MODEL_ID_A, trust_remote_code=True)
+model_a = AutoModelForImageTextToText.from_pretrained(
+    MODEL_ID_A,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
+
+# Load olmOCR-7B-0725
+MODEL_ID_W = "allenai/olmOCR-7B-0725"
+processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True)
+model_w = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_W,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
+
+# Load RolmOCR
+MODEL_ID_M = "reducto/RolmOCR"
+processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
+model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_M,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
 
 def downsample_video(video_path):
+    """
+    Downsamples the video to evenly spaced frames.
+    Each frame is returned as a PIL image along with its timestamp.
+    """
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
     frames = []
+    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
     for i in frame_indices:
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
         success, image = vidcap.read()
         if success:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
     vidcap.release()
     return frames
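
For reference (not part of the committed file): the np.linspace call above picks 10 evenly spaced frame indices across the whole clip. A minimal standalone sketch, with a made-up frame count, shows the indices it produces:

import numpy as np

total_frames = 300  # hypothetical clip length in frames
indices = np.linspace(0, total_frames - 1, 10, dtype=int)
print(indices)  # -> [  0  33  66  99 132 166 199 232 265 299]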
 
+@spaces.GPU
+def generate_image(model_name: str, text: str, image: Image.Image,
+                   max_new_tokens: int = 1024,
+                   temperature: float = 0.6,
+                   top_p: float = 0.9,
+                   top_k: int = 50,
+                   repetition_penalty: float = 1.2):
+    """
+    Generates responses using the selected model for image input.
+    Yields raw text and Markdown-formatted text.
+    """
+    if model_name == "RolmOCR-7B":
+        processor = processor_m
+        model = model_m
+    elif model_name == "Qwen2-VL-OCR-2B":
+        processor = processor_x
+        model = model_x
+    elif model_name == "Nanonets-OCR-s":
+        processor = processor_v
+        model = model_v
+    elif model_name == "Aya-Vision-8B":
+        processor = processor_a
+        model = model_a
+    elif model_name == "olmOCR-7B-0725":
+        processor = processor_w
+        model = model_w
     else:
+        yield "Invalid model selected.", "Invalid model selected."
+        return
 
     if image is None:
         yield "Please upload an image.", "Please upload an image."
         return
 
+    messages = [{
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": text},
+        ]
+    }]
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(
+        text=[prompt_full],
+        images=[image],
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=MAX_INPUT_TOKEN_LENGTH
+    ).to(device)
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
     buffer = ""
     for new_text in streamer:

         yield buffer, buffer
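
For reference (not part of the committed file): generate_image and generate_video both stream tokens by running model.generate on a background thread while the main thread drains a TextIteratorStreamer. A minimal, self-contained sketch of that pattern, using a small text-only model purely for illustration:

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")           # illustrative model choice only
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Streaming example:", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs in a worker thread; the streamer hands decoded
# text chunks to the consuming loop as soon as they are produced.
thread = Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 32})
thread.start()

buffer = ""
for new_text in streamer:
    buffer += new_text
print(buffer)
thread.join()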
 
 @spaces.GPU
+def generate_video(model_name: str, text: str, video_path: str,
+                   max_new_tokens: int = 1024,
+                   temperature: float = 0.6,
+                   top_p: float = 0.9,
+                   top_k: int = 50,
+                   repetition_penalty: float = 1.2):
+    """
+    Generates responses using the selected model for video input.
+    Yields raw text and Markdown-formatted text.
+    """
+    if model_name == "RolmOCR-7B":
+        processor = processor_m
+        model = model_m
+    elif model_name == "Qwen2-VL-OCR-2B":
+        processor = processor_x
+        model = model_x
+    elif model_name == "Nanonets-OCR-s":
+        processor = processor_v
+        model = model_v
+    elif model_name == "Aya-Vision-8B":
+        processor = processor_a
+        model = model_a
+    elif model_name == "olmOCR-7B-0725":
+        processor = processor_w
+        model = model_w
+    else:
+        yield "Invalid model selected.", "Invalid model selected."
         return
 
+    if video_path is None:
+        yield "Please upload a video.", "Please upload a video."
         return
 
+    frames_with_ts = downsample_video(video_path)
+    images_for_processor = [frame for frame, ts in frames_with_ts]
+
     messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
+    for frame in images_for_processor:
         messages[0]["content"].insert(0, {"type": "image"})
+
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+    inputs = processor(
+        text=[prompt_full],
+        images=images_for_processor,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=MAX_INPUT_TOKEN_LENGTH
+    ).to(device)
+
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {
+        **inputs,
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "do_sample": True,
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "repetition_penalty": repetition_penalty,
+    }
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
     buffer = ""
     for new_text in streamer:

         time.sleep(0.01)
         yield buffer, buffer
 
+# Define examples for image and video inference
+image_examples = [
+    ["Extract the full page.", "images/ocr.png"],
+    ["Extract the content.", "images/4.png"],
+    ["Explain the scene.", "images/3.jpg"],
+    ["Convert this page to doc [table] precisely for markdown.", "images/0.png"],
+    ["Perform OCR on the Image.", "images/1.jpg"],
+    ["Extract the table content.", "images/2.png"]
+]
+
+video_examples = [
+    ["Explain the Ad in Detail.", "videos/1.mp4"],
+    ["Identify the main actions in the cartoon video.", "videos/2.mp4"]
+]
+
+css = """
+#main-title h1 {
+    font-size: 2.3em !important;
+}
+#output-title h2 {
+    font-size: 2.1em !important;
+}
+"""
 
+# Create the Gradio Interface
+with gr.Blocks(css=css, theme=pale_green_theme) as demo:
+    gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
     with gr.Row():
         with gr.Column(scale=2):
             with gr.Tabs():
                 with gr.TabItem("Image Inference"):
                     image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                    image_upload = gr.Image(type="pil", label="Upload Image", height=290)
                     image_submit = gr.Button("Submit", variant="primary")
+                    gr.Examples(
+                        examples=image_examples,
+                        inputs=[image_query, image_upload]
+                    )
                 with gr.TabItem("Video Inference"):
                     video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                    video_upload = gr.Video(label="Upload Video", height=290)
                     video_submit = gr.Button("Submit", variant="primary")
+                    gr.Examples(
+                        examples=video_examples,
+                        inputs=[video_query, video_upload]
+                    )
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
+
        with gr.Column(scale=3):
+            gr.Markdown("## Output", elem_id="output-title")
+            output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
+            with gr.Accordion("(Result.md)", open=False):
+                markdown_output = gr.Markdown(label="(Result.Md)", latex_delimiters=[
+                    {"left": "$$", "right": "$$", "display": True},
+                    {"left": "$", "right": "$", "display": False}
+                ])
+
+            model_choice = gr.Radio(
+                choices=["olmOCR-7B-0725", "Nanonets-OCR-s", "RolmOCR-7B",
+                         "Aya-Vision-8B", "Qwen2-VL-OCR-2B"],
+                label="Select Model",
+                value="olmOCR-7B-0725"
+            )
+
+    image_submit.click(
+        fn=generate_image,
+        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        outputs=[output, markdown_output]
+    )
+    video_submit.click(
+        fn=generate_video,
+        inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        outputs=[output, markdown_output]
+    )
 
 if __name__ == "__main__":
     demo.queue(max_size=50).launch(mcp_server=True, ssr_mode=False, show_error=True)
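
For reference (not part of the committed file): the if/elif chains in generate_image and generate_video resolve the gr.Radio value to one of the five (processor, model) pairs loaded at startup. A sketch of the same dispatch written as a lookup table, assuming the processor_*/model_* globals defined in app.py:

# Keys match the gr.Radio choices; values are the pairs loaded above.
MODEL_REGISTRY = {
    "olmOCR-7B-0725": (processor_w, model_w),
    "Nanonets-OCR-s": (processor_v, model_v),
    "RolmOCR-7B": (processor_m, model_m),
    "Aya-Vision-8B": (processor_a, model_a),
    "Qwen2-VL-OCR-2B": (processor_x, model_x),
}

def resolve_model(model_name):
    # Returns (processor, model), or None when the name is not registered.
    return MODEL_REGISTRY.get(model_name)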