prithivMLmods committed
Commit e1ef66b · verified · 1 Parent(s): 639d6a2

update app

Files changed (1):
  1. app.py +213 -53
app.py CHANGED
@@ -12,7 +12,7 @@ import spaces
 import torch
 import numpy as np
 from PIL import Image
-# cv2 is no longer needed as video processing is removed
+import cv2
 
 from transformers import (
     Qwen2VLForConditionalGeneration,
@@ -27,29 +27,35 @@ from gradio.themes.utils import colors, fonts, sizes
 
 # --- Theme and CSS Definition ---
 
-# Define the new SpringGreen color palette
-colors.spring_green = colors.Color(
-    name="spring_green",
-    c50="#E5FFF2",
-    c100="#CCFFEC",
-    c200="#99FFD9",
-    c300="#66FFC6",
-    c400="#33FFB3",
-    c500="#00FF7F",  # SpringGreen base color
-    c600="#00E672",
-    c700="#00CC66",
-    c800="#00B359",
-    c900="#00994D",
-    c950="#008040",
+# Define the Thistle color palette
+colors.thistle = colors.Color(
+    name="thistle",
+    c50="#F9F5F9",
+    c100="#F0E8F1",
+    c200="#E7DBE8",
+    c300="#DECEE0",
+    c400="#D2BFD8",
+    c500="#D8BFD8",  # Thistle base color
+    c600="#B59CB7",
+    c700="#927996",
+    c800="#6F5675",
+    c900="#4C3454",
+    c950="#291233",
 )
 
+colors.red_gray = colors.Color(
+    name="red_gray",
+    c50="#f7eded", c100="#f5dcdc", c200="#efb4b4", c300="#e78f8f",
+    c400="#d96a6a", c500="#c65353", c600="#b24444", c700="#8f3434",
+    c800="#732d2d", c900="#5f2626", c950="#4d2020",
+)
 
-class SpringGreenTheme(Soft):
+class ThistleTheme(Soft):
     def __init__(
         self,
         *,
        primary_hue: colors.Color | str = colors.gray,
-        secondary_hue: colors.Color | str = colors.spring_green,  # Use the new color
+        secondary_hue: colors.Color | str = colors.thistle,  # Use the new color
         neutral_hue: colors.Color | str = colors.slate,
         text_size: sizes.Size | str = sizes.text_lg,
         font: fonts.Font | str | Iterable[fonts.Font | str] = (
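Note: for anyone reproducing the theming pattern outside this Space, a registered `colors.Color` palette can be handed to a stock `Soft` theme without subclassing. A minimal sketch, assuming only `gradio` is installed (the demo app below is illustrative, not part of this commit):

```python
import gradio as gr
from gradio.themes import Soft
from gradio.themes.utils import colors

# Register a custom palette the same way the commit does
# (hex values copied from the thistle scale above).
colors.thistle = colors.Color(
    name="thistle",
    c50="#F9F5F9", c100="#F0E8F1", c200="#E7DBE8", c300="#DECEE0",
    c400="#D2BFD8", c500="#D8BFD8", c600="#B59CB7", c700="#927996",
    c800="#6F5675", c900="#4C3454", c950="#291233",
)

# Subclassing (as in app.py) is only needed to bake in extra style
# overrides; for the hues alone, passing the palette directly works.
demo = gr.Interface(fn=lambda s: s, inputs="text", outputs="text",
                    theme=Soft(secondary_hue=colors.thistle))

if __name__ == "__main__":
    demo.launch()
```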
@@ -78,6 +84,12 @@ class SpringGreenTheme(Soft):
         button_primary_background_fill_hover="linear-gradient(90deg, *secondary_500, *secondary_600)",
         button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
         button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
+        button_secondary_text_color="black",
+        button_secondary_text_color_hover="white",
+        button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
+        button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
+        button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
+        button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
         slider_color="*secondary_400",
         slider_color_dark="*secondary_600",
         block_title_text_weight="600",
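Note: the `*secondary_500`-style strings above are not literal CSS; Gradio resolves `*name` references against the theme's own palette variables when the page is built. The same overrides can be applied to a stock theme via `.set()`; a small sketch using values from the hunk above:

```python
from gradio.themes import Soft

# .set() accepts the same keyword overrides the subclass passes to
# super().__init__(); "*secondary_400" is resolved to the theme's
# secondary palette when the theme is rendered.
theme = Soft().set(
    button_primary_background_fill="linear-gradient(90deg, *secondary_400, *secondary_500)",
    slider_color="*secondary_400",
)
```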
@@ -90,8 +102,7 @@ class SpringGreenTheme(Soft):
     )
 
 # Instantiate the new theme
-spring_green_theme = SpringGreenTheme()
-
+thistle_theme = ThistleTheme()
 
 css = """
 #main-title h1 {
@@ -100,12 +111,56 @@ css = """
 #output-title h2 {
     font-size: 2.1em !important;
 }
+:root {
+    --color-grey-50: #f9fafb;
+    --banner-background: var(--secondary-400);
+    --banner-text-color: var(--primary-100);
+    --banner-background-dark: var(--secondary-800);
+    --banner-text-color-dark: var(--primary-100);
+    --banner-chrome-height: calc(16px + 43px);
+    --chat-chrome-height-wide-no-banner: 320px;
+    --chat-chrome-height-narrow-no-banner: 450px;
+    --chat-chrome-height-wide: calc(var(--chat-chrome-height-wide-no-banner) + var(--banner-chrome-height));
+    --chat-chrome-height-narrow: calc(var(--chat-chrome-height-narrow-no-banner) + var(--banner-chrome-height));
+}
+.banner-message { background-color: var(--banner-background); padding: 5px; margin: 0; border-radius: 5px; border: none; }
+.banner-message-text { font-size: 13px; font-weight: bolder; color: var(--banner-text-color) !important; }
+body.dark .banner-message { background-color: var(--banner-background-dark) !important; }
+body.dark .gradio-container .contain .banner-message .banner-message-text { color: var(--banner-text-color-dark) !important; }
+.toast-body { background-color: var(--color-grey-50); }
+.html-container:has(.css-styles) { padding: 0; margin: 0; }
+.css-styles { height: 0; }
+.model-message { text-align: end; }
+.model-dropdown-container { display: flex; align-items: center; gap: 10px; padding: 0; }
+.user-input-container .multimodal-textbox{ border: none !important; }
+.control-button { height: 51px; }
+button.cancel { border: var(--button-border-width) solid var(--button-cancel-border-color); background: var(--button-cancel-background-fill); color: var(--button-cancel-text-color); box-shadow: var(--button-cancel-shadow); }
+button.cancel:hover, .cancel[disabled] { background: var(--button-cancel-background-fill-hover); color: var(--button-cancel-text-color-hover); }
+.opt-out-message { top: 8px; }
+.opt-out-message .html-container, .opt-out-checkbox label { font-size: 14px !important; padding: 0 !important; margin: 0 !important; color: var(--neutral-400) !important; }
+div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-wide)) !important; max-height: 900px !important; }
+div.no-padding { padding: 0 !important; }
+@media (max-width: 1280px) { div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-wide)) !important; } }
+@media (max-width: 1024px) {
+    .responsive-row { flex-direction: column; }
+    .model-message { text-align: start; font-size: 10px !important; }
+    .model-dropdown-container { flex-direction: column; align-items: flex-start; }
+    div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-narrow)) !important; }
+}
+@media (max-width: 400px) {
+    .responsive-row { flex-direction: column; }
+    .model-message { text-align: start; font-size: 10px !important; }
+    .model-dropdown-container { flex-direction: column; align-items: flex-start; }
+    div.block.chatbot { max-height: 360px !important; }
+}
+@media (max-height: 932px) { .chatbot { max-height: 500px !important; } }
+@media (max-height: 1280px) { div.block.chatbot { max-height: 800px !important; } }
 """
 
 # Constants for text generation
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
-# Increased max_length to accommodate more complex inputs
+# Increased max_length to accommodate more complex inputs, especially with multiple images
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
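Note: since `MAX_INPUT_TOKEN_LENGTH` is read from the environment at import time, the cap can be raised per-deployment without editing the file. A sketch (the value is illustrative):

```python
import os

# Must be set before app.py is imported, because the constant is
# evaluated at module load time; on a HF Space this would normally be
# configured via the Space's environment variables instead.
os.environ["MAX_INPUT_TOKEN_LENGTH"] = "16384"
```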
@@ -153,7 +208,7 @@ model_a = AutoModelForImageTextToText.from_pretrained(
 MODEL_ID_W = "allenai/olmOCR-7B-0725"
 processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True)
 model_w = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_W,
+    MODEL_ID_W,
     trust_remote_code=True,
     torch_dtype=torch.float16
 ).to(device).eval()
@@ -167,6 +222,27 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to(device).eval()
 
+def downsample_video(video_path):
+    """
+    Downsamples the video to evenly spaced frames.
+    Each frame is returned as a PIL image along with its timestamp.
+    """
+    vidcap = cv2.VideoCapture(video_path)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    frames = []
+    # Use a maximum of 10 frames to avoid excessive memory usage
+    frame_indices = np.linspace(0, total_frames - 1, min(total_frames, 10), dtype=int)
+    for i in frame_indices:
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
+    vidcap.release()
+    return frames
 
 @spaces.GPU
 def generate_image(model_name: str, text: str, image: Image.Image,
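Note: a quick way to sanity-check the new helper, assuming it is available in the current session; `videos/1.mp4` is the example clip this commit references later:

```python
# Illustrative check of downsample_video: it returns at most 10 evenly
# spaced (PIL.Image, timestamp-in-seconds) pairs.
frames = downsample_video("videos/1.mp4")
print(f"sampled {len(frames)} frames")
for pil_image, timestamp in frames:
    print(f"t={timestamp:6.2f}s  size={pil_image.size}")
```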
@@ -210,9 +286,8 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         ]
     }]
     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
-    # FIX: Set truncation to False and rely on the model's context length.
-    # The increased MAX_INPUT_TOKEN_LENGTH at the top also helps.
+
+    # FIX: Set truncation to False to avoid the ValueError
     inputs = processor(
         text=[prompt_full],
         images=[image],
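Note: the hunk's context ends before the argument the `FIX` comment refers to; the full call presumably looks something like the sketch below (keyword list reconstructed, not verbatim from the file). Truncating the prompt can drop image placeholder tokens and raise a ValueError inside the processor, which is what disabling truncation avoids:

```python
# Assumed shape of the processor call with truncation disabled.
inputs = processor(
    text=[prompt_full],
    images=[image],
    return_tensors="pt",
    padding=True,
    truncation=False,  # rely on the model's context window instead
).to(device)
```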
@@ -231,53 +306,138 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         time.sleep(0.01)
         yield buffer, buffer
 
+@spaces.GPU
+def generate_video(model_name: str, text: str, video_path: str,
+                   max_new_tokens: int = 1024,
+                   temperature: float = 0.6,
+                   top_p: float = 0.9,
+                   top_k: int = 50,
+                   repetition_penalty: float = 1.2):
+    """
+    Generates responses using the selected model for video input.
+    Yields raw text and Markdown-formatted text.
+    """
+    if model_name == "RolmOCR-7B":
+        processor = processor_m
+        model = model_m
+    elif model_name == "Qwen2-VL-OCR-2B":
+        processor = processor_x
+        model = model_x
+    elif model_name == "Nanonets-OCR2-3B":
+        processor = processor_v
+        model = model_v
+    elif model_name == "Aya-Vision-8B":
+        processor = processor_a
+        model = model_a
+    elif model_name == "olmOCR-7B-0725":
+        processor = processor_w
+        model = model_w
+    else:
+        yield "Invalid model selected.", "Invalid model selected."
+        return
+
+    if video_path is None:
+        yield "Please upload a video.", "Please upload a video."
+        return
+
+    frames_with_ts = downsample_video(video_path)
+    images_for_processor = [frame for frame, ts in frames_with_ts]
+
+    messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
+    for frame in images_for_processor:
+        messages[0]["content"].insert(0, {"type": "image"})
+
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+    inputs = processor(
+        text=[prompt_full],
+        images=images_for_processor,
+        return_tensors="pt",
+        padding=True
+    ).to(device)
+
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {
+        **inputs,
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "do_sample": True,
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "repetition_penalty": repetition_penalty,
+    }
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+    buffer = ""
+    for new_text in streamer:
+        buffer += new_text
+        buffer = buffer.replace("<|im_end|>", "")
+        time.sleep(0.01)
+        yield buffer, buffer
 
-# Define examples for image inference
+# Define examples for image and video inference
 image_examples = [
-    ["Extract the full page.", "images/ocr.png"],
-    ["Extract the content.", "images/4.png"],
+    ["Extract the full page.", "images/ocr.png"],
+    ["Extract the content.", "images/4.png"],
     ["Convert this page to doc [table] precisely for markdown.", "images/0.png"]
 ]
 
+video_examples = [
+    ["Explain the Ad in Detail.", "videos/1.mp4"],
+]
 
 # Create the Gradio Interface
-with gr.Blocks(css=css, theme=spring_green_theme) as demo:
+with gr.Blocks(css=css, theme=thistle_theme) as demo:
     gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
     with gr.Row():
         with gr.Column(scale=2):
-            image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-            image_upload = gr.Image(type="pil", label="Upload Image", height=290)
-            image_submit = gr.Button("Submit", variant="primary")
-            gr.Examples(
-                examples=image_examples,
-                inputs=[image_query, image_upload]
-            )
+            with gr.Tabs():
+                with gr.TabItem("Image Inference"):
+                    image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                    image_upload = gr.Image(type="pil", label="Upload Image", height=290)
+                    image_submit = gr.Button("Submit", variant="primary")
+                    gr.Examples(
+                        examples=image_examples,
+                        inputs=[image_query, image_upload]
+                    )
+                with gr.TabItem("Video Inference"):
+                    video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                    video_upload = gr.Video(label="Upload Video", height=290)
+                    video_submit = gr.Button("Submit", variant="primary")
+                    gr.Examples(
+                        examples=video_examples,
+                        inputs=[video_query, video_upload]
+                    )
+            gr.Markdown("> Only the olmOCR and RolmOCR models currently support video inference (max video length: 30 secs).")
             with gr.Accordion("Advanced options", open=False):
-                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1,
-                                           value=DEFAULT_MAX_NEW_TOKENS)
+                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                 top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
-                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05,
-                                               value=1.2)
-
+                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
+
         with gr.Column(scale=3):
-            gr.Markdown("## Output", elem_id="output-title")
-            output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
-            with gr.Accordion("(Result.md)", open=False):
-                markdown_output = gr.Markdown(label="(Result.Md)")
-
-            model_choice = gr.Radio(
-                choices=["Nanonets-OCR2-3B", "olmOCR-7B-0725", "RolmOCR-7B",
+            gr.Markdown("## Output", elem_id="output-title")
+            output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
+            with gr.Accordion("(Result.md)", open=False):
+                markdown_output = gr.Markdown(label="(Result.Md)")
+
+            model_choice = gr.Radio(
+                choices=["Nanonets-OCR2-3B", "olmOCR-7B-0725", "RolmOCR-7B",
                          "Aya-Vision-8B", "Qwen2-VL-OCR-2B"],
-                label="Select Model",
-                value="Nanonets-OCR2-3B"
-            )
-
+                label="Select Model",
+                value="Nanonets-OCR2-3B"
+            )
+
     image_submit.click(
         fn=generate_image,
-        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k,
-                repetition_penalty],
+        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        outputs=[output, markdown_output]
+    )
+    video_submit.click(
+        fn=generate_video,
+        inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
         outputs=[output, markdown_output]
     )
 
 
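Note: both `generate_image` and `generate_video` are generator functions, so outside Gradio the stream can be consumed directly. A minimal sketch, with argument values borrowed from the examples in this commit:

```python
# Each iteration yields (raw_text, markdown_text) holding the text so far.
for raw_text, markdown_text in generate_video(
    model_name="RolmOCR-7B",            # one of the video-capable models
    text="Explain the Ad in Detail.",
    video_path="videos/1.mp4",
    max_new_tokens=512,
):
    print(f"\r{len(raw_text)} chars streamed", end="")
print()
```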