Spaces: Running on Zero
update app
app.py CHANGED
@@ -4,11 +4,11 @@ import uuid
 import json
 import time
 import asyncio
+import re
 from threading import Thread
 from pathlib import Path
 from io import BytesIO
 from typing import Optional, Tuple, Dict, Any, Iterable
-import re
 
 import gradio as gr
 import spaces
@@ -30,6 +30,9 @@ from transformers.image_utils import load_image
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 
+# --- Theme and CSS Definition ---
+
+# Define the new OrangeRed color palette
 colors.orange_red = colors.Color(
     name="orange_red",
     c50="#FFF0E5",
@@ -37,7 +40,7 @@ colors.orange_red = colors.Color(
     c200="#FFC299",
     c300="#FFA366",
     c400="#FF8533",
-    c500="#FF4500",
+    c500="#FF4500", # OrangeRed base color
     c600="#E63E00",
     c700="#CC3700",
     c800="#B33000",
@@ -96,6 +99,7 @@ class OrangeRedTheme(Soft):
     block_label_background_fill="*primary_200",
 )
 
+# Instantiate the new theme
 orange_red_theme = OrangeRedTheme()
 
 css = """
@@ -173,6 +177,59 @@ model_q3vl = Qwen3VLMoeForConditionalGeneration.from_pretrained(
     dtype=torch.float16
 ).to(device).eval()
 
+# --- Utility functions for Detection and Drawing ---
+
+def parse_detection_output(text: str) -> list:
+    """Parses the model's text output to extract bounding boxes or points."""
+    match = re.search(r'\[\s*\[.*?\]\s*\]', text)
+    if not match:
+        return []
+    try:
+        result = json.loads(match.group(0))
+        if isinstance(result, list) and all(isinstance(item, list) for item in result):
+            return result
+        return []
+    except (json.JSONDecodeError, TypeError):
+        return []
+
+def draw_object_detections(image: Image.Image, detections: list, labels: list) -> Image.Image:
+    """Draws bounding boxes on the image."""
+    image_np = np.array(image.convert("RGB"))
+    h, w, _ = image_np.shape
+    boxes = []
+    for box in detections:
+        if len(box) == 4:
+            x1, y1, x2, y2 = box
+            boxes.append([x1 * w, y1 * h, x2 * w, y2 * h])
+    if not boxes:
+        return image
+    detections_sv = sv.Detections(xyxy=np.array(boxes))
+    bounding_box_annotator = sv.BoxAnnotator(thickness=2)
+    label_annotator = sv.LabelAnnotator(text_thickness=1, text_scale=0.5)
+    annotated_image = bounding_box_annotator.annotate(scene=image_np.copy(), detections=detections_sv)
+    annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections_sv, labels=labels)
+    return Image.fromarray(annotated_image)
+
+def draw_point_detections(image: Image.Image, points: list) -> Image.Image:
+    """Draws points on the image."""
+    image_np = np.array(image.convert("RGB"))
+    h, w, _ = image_np.shape
+    pts = []
+    for point in points:
+        if len(point) == 2:
+            x, y = point
+            pts.append([x * w, y * h])
+    if not pts:
+        return image
+    points_np = np.array(pts).reshape(1, -1, 2)
+    key_points = sv.KeyPoints(xy=points_np)
+    point_annotator = sv.VertexAnnotator(radius=5, color=sv.Color.RED)
+    annotated_image = point_annotator.annotate(scene=image_np.copy(), key_points=key_points)
+    return Image.fromarray(annotated_image)
+
+
+# --- Core Generation Functions ---
+
 def extract_gif_frames(gif_path: str):
     if not gif_path:
         return []
@@ -249,74 +306,6 @@ def navigate_pdf_page(direction: str, state: Dict[str, Any]):
     page_info_html = f'<div style="text-align:center;">Page {new_index + 1} / {total_pages}</div>'
     return image_preview, state, page_info_html
 
-def draw_boxes_on_image(image: Image.Image, text_output: str, object_name: str) -> Tuple[Image.Image, str]:
-    try:
-        # Extract the JSON part of the text output
-        match = re.search(r'\[\s*\[.*?\]\s*\]', text_output, re.DOTALL)
-        if not match:
-            return image, f"Could not find coordinates in the model output: {text_output}"
-
-        boxes_str = match.group(0)
-        boxes = json.loads(boxes_str)
-
-        if not boxes or not isinstance(boxes[0], list):
-            return image, f"No valid boxes found in parsed data: {boxes}"
-
-        width, height = image.size
-        np_image = np.array(image.convert("RGB"))
-
-        # Denormalize coordinates
-        xyxy = []
-        for box in boxes:
-            x1, y1, x2, y2 = box
-            xyxy.append([x1 * width, y1 * height, x2 * width, y2 * height])
-
-        detections = sv.Detections(xyxy=np.array(xyxy))
-
-        bounding_box_annotator = sv.BoxAnnotator(thickness=2)
-        label_annotator = sv.LabelAnnotator(text_thickness=1, text_scale=0.5)
-
-        labels = [f"{object_name} #{i+1}" for i in range(len(detections))]
-
-        annotated_image = bounding_box_annotator.annotate(scene=np_image.copy(), detections=detections)
-        annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
-
-        return Image.fromarray(annotated_image), text_output
-    except (json.JSONDecodeError, IndexError, TypeError) as e:
-        return image, f"Failed to parse or draw boxes. Error: {e}\nModel Output:\n{text_output}"
-
-def draw_points_on_image(image: Image.Image, text_output: str) -> Tuple[Image.Image, str]:
-    try:
-        match = re.search(r'\[\s*\[.*?\]\s*\]', text_output, re.DOTALL)
-        if not match:
-            return image, f"Could not find coordinates in the model output: {text_output}"
-
-        points_str = match.group(0)
-        points = json.loads(points_str)
-
-        if not points or not isinstance(points[0], list):
-            return image, f"No valid points found in parsed data: {points}"
-
-        width, height = image.size
-        np_image = np.array(image.convert("RGB"))
-
-        # Denormalize coordinates
-        xy = []
-        for point in points:
-            x, y = point
-            xy.append([x * width, y * height])
-
-        points_array = np.array(xy).reshape(1, -1, 2)
-        key_points = sv.KeyPoints(xy=points_array)
-
-        point_annotator = sv.VertexAnnotator(radius=5, color=sv.Color.RED)
-        annotated_image = point_annotator.annotate(scene=np_image.copy(), key_points=key_points)
-
-        return Image.fromarray(annotated_image), text_output
-    except (json.JSONDecodeError, IndexError, TypeError) as e:
-        return image, f"Failed to parse or draw points. Error: {e}\nModel Output:\n{text_output}"
-
-
 @spaces.GPU
 def generate_image(text: str, image: Image.Image, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
     if image is None:
@@ -437,62 +426,55 @@ def generate_gif(text: str, gif_path: str, max_new_tokens: int = 1024, temperatu
     yield buffer, buffer
 
 @spaces.GPU
-def
+def generate_detection(
+    image: Image.Image, user_prompt: str, task_type: str, max_new_tokens: int = 256,
+    temperature: float = 0.1, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2
+):
     if image is None:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return None, "Please upload an image."
+    if not user_prompt:
+        return image, "Please provide a prompt describing what to detect."
+
+    if task_type == "Object Detection":
+        system_prompt = (
+            f"You are an expert object detector. Find all instances of '{user_prompt}' in the image. "
+            "Respond ONLY with a Python list of bounding boxes in the format [[x_min, y_min, x_max, y_max], ...]. "
+            "The coordinates must be normalized between 0.0 and 1.0."
+        )
+    elif task_type == "Point Detection":
+        system_prompt = (
+            f"You are an expert keypoint detector. Find the specific points for '{user_prompt}' in the image. "
+            "Respond ONLY with a Python list of points in the format [[x, y], ...]. "
+            "The coordinates must be normalized between 0.0 and 1.0."
+        )
+    else:
+        return image, "Invalid task type specified."
+
+    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": system_prompt}]}]
     prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     inputs = processor_q3vl(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
-
-    # This task is not streamed because we need the full output to parse and draw boxes
-    outputs = model_q3vl.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
-    response_text = processor_q3vl.decode(outputs[0], skip_special_tokens=True).strip()
-
-    # Extract only the user-facing part of the response
-    final_text = response_text.split('<|im_end|>')[-1].strip() if '<|im_end|>' in response_text else response_text
 
-
-
+    generation_kwargs = {
+        **inputs, "max_new_tokens": max_new_tokens, "do_sample": True, "temperature": temperature,
+        "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty,
+    }
 
-
-
-
-        yield image, "Please upload an image."
-        return
-    if not text:
-        yield image, "Please enter the object/point name to detect."
-        return
-
-    prompt = (
-        f"You are an expert point detection model. Your task is to find the specific location of '{text}' in the image. "
-        "You must respond ONLY with a JSON list containing a single coordinate pair. The coordinate must be in the format "
-        "[[x, y]], where the coordinates are normalized to be between 0 and 1. "
-        "Do not provide any other text, explanation, or preamble. For example: [[0.45, 0.67]]"
-    )
-
-    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
-    prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor_q3vl(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
-
-    outputs = model_q3vl.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
-    response_text = processor_q3vl.decode(outputs[0], skip_special_tokens=True).strip()
-
-    final_text = response_text.split('<|im_end|>')[-1].strip() if '<|im_end|>' in response_text else response_text
+    generate_ids = model_q3vl.generate(**generation_kwargs)
+    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
+    response_text = processor_q3vl.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
-
-
+    try:
+        coords = parse_detection_output(response_text)
+        if not coords:
+            return image, f"Could not detect '{user_prompt}'.\nModel raw output:\n{response_text}"
+        if task_type == "Object Detection":
+            labels = [f"{user_prompt} #{i+1}" for i in range(len(coords))]
+            annotated_image = draw_object_detections(image, coords, labels)
+        else: # Point Detection
+            annotated_image = draw_point_detections(image, coords)
+        return annotated_image, response_text
+    except Exception as e:
+        return image, f"An error occurred during processing:\n{str(e)}\n\nModel raw output:\n{response_text}"
 
 
 image_examples = [["Perform OCR on the image...", "examples/images/1.jpg"],
@@ -506,10 +488,9 @@ gif_examples = [["Describe this GIF.", "examples/gifs/1.gif"],
                 ["Describe this GIF.", "examples/gifs/2.gif"]]
 caption_examples = [["examples/captions/1.JPG"],
                     ["examples/captions/2.jpeg"], ["examples/captions/3.jpeg"]]
-
-
-
-                            ["the clock on the wall", "examples/detection/room.jpg"]]
+# NOTE: You'll need to create these example image files in a directory named 'examples/detection/'
+obj_det_examples = [["examples/detection/obj1.jpg", "the two people"], ["examples/detection/obj2.jpg", "the yellow taxi"]]
+point_det_examples = [["examples/detection/point1.jpg", "the eyes of the person"], ["examples/detection/point2.jpg", "the headlights of the car"]]
 
 
 with gr.Blocks(theme=orange_red_theme, css=css) as demo:
@@ -524,17 +505,11 @@ with gr.Blocks(theme=orange_red_theme, css=css) as demo:
                     image_submit = gr.Button("Submit", variant="primary")
                     gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
 
-                with gr.TabItem("
-
-
-
-                    gr.Examples(examples=
-
-                with gr.TabItem("Point Detection"):
-                    point_det_query = gr.Textbox(label="Point to Detect", placeholder="e.g., 'the cat's left eye'")
-                    point_det_upload = gr.Image(type="pil", label="Upload Image", height=290)
-                    point_det_submit = gr.Button("Detect Point", variant="primary")
-                    gr.Examples(examples=point_detection_examples, inputs=[point_det_query, point_det_upload])
+                with gr.TabItem("Video Inference"):
+                    video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                    video_upload = gr.Video(label="Upload Video(≤30s)", height=290)
+                    video_submit = gr.Button("Submit", variant="primary")
+                    gr.Examples(examples=video_examples, inputs=[video_query, video_upload])
 
                 with gr.TabItem("PDF Inference"):
                     with gr.Row():
@@ -555,17 +530,33 @@ with gr.Blocks(theme=orange_red_theme, css=css) as demo:
                     gif_upload = gr.Image(type="filepath", label="Upload GIF", height=290)
                     gif_submit = gr.Button("Submit", variant="primary")
                     gr.Examples(examples=gif_examples, inputs=[gif_query, gif_upload])
-
+
                 with gr.TabItem("Caption"):
                     caption_image_upload = gr.Image(type="pil", label="Image to Caption", height=290)
                     caption_submit = gr.Button("Generate Caption", variant="primary")
                     gr.Examples(examples=caption_examples, inputs=[caption_image_upload])
-
-                with gr.TabItem("
-
-
-
-
+
+                with gr.TabItem("Object Detection"):
+                    with gr.Row():
+                        with gr.Column(scale=1):
+                            obj_det_image_upload = gr.Image(type="pil", label="Upload Image", height=290)
+                            obj_det_query = gr.Textbox(label="Object to Detect", placeholder="e.g., car, person, dog")
+                            obj_det_submit = gr.Button("Detect Objects", variant="primary")
+                        with gr.Column(scale=1):
+                            obj_det_output_image = gr.Image(type="pil", label="Detection Result", height=290)
+                            obj_det_output_text = gr.Textbox(label="Model Raw Output", interactive=False, lines=5)
+                    gr.Examples(examples=obj_det_examples, inputs=[obj_det_image_upload, obj_det_query])
+
+                with gr.TabItem("Point Detection"):
+                    with gr.Row():
+                        with gr.Column(scale=1):
+                            point_det_image_upload = gr.Image(type="pil", label="Upload Image", height=290)
+                            point_det_query = gr.Textbox(label="Point(s) to Detect", placeholder="e.g., the eyes of the cat")
+                            point_det_submit = gr.Button("Detect Points", variant="primary")
+                        with gr.Column(scale=1):
+                            point_det_output_image = gr.Image(type="pil", label="Detection Result", height=290)
+                            point_det_output_text = gr.Textbox(label="Model Raw Output", interactive=False, lines=5)
+                    gr.Examples(examples=point_det_examples, inputs=[point_det_image_upload, point_det_query])
 
             with gr.Accordion("Advanced options", open=False):
                 max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
@@ -576,70 +567,39 @@ with gr.Blocks(theme=orange_red_theme, css=css) as demo:
 
         with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
-            output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=
-
+            output = gr.Textbox(label="Raw Output Stream (General Tasks)", interactive=False, lines=20, show_copy_button=True)
+            with gr.Accordion("(Result.md)", open=False):
+                markdown_output = gr.Markdown(label="(Result.Md)", latex_delimiters=[
                     {"left": "$$", "right": "$$", "display": True},
                     {"left": "$", "right": "$", "display": False}
-                ]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                       outputs=[output, markdown_output])
-    video_submit.click(fn=generate_video,
-                       inputs=[video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                       outputs=[output, markdown_output])
-    pdf_submit.click(fn=generate_pdf,
-                     inputs=[pdf_query, pdf_state, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                     outputs=[output, markdown_output])
-    gif_submit.click(fn=generate_gif,
-                     inputs=[gif_query, gif_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                     outputs=[output, markdown_output])
-    caption_submit.click(fn=generate_caption,
-                         inputs=[caption_image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                         outputs=[output, markdown_output])
-
+                ])
+
+    # Click handlers for original tabs
+    image_submit.click(fn=generate_image, inputs=[image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[output, markdown_output])
+    video_submit.click(fn=generate_video, inputs=[video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[output, markdown_output])
+    pdf_submit.click(fn=generate_pdf, inputs=[pdf_query, pdf_state, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[output, markdown_output])
+    gif_submit.click(fn=generate_gif, inputs=[gif_query, gif_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[output, markdown_output])
+    caption_submit.click(fn=generate_caption, inputs=[caption_image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[output, markdown_output])
+
+    # PDF navigation handlers
+    pdf_upload.change(fn=load_and_preview_pdf, inputs=[pdf_upload], outputs=[pdf_preview_img, pdf_state, page_info])
+    prev_page_btn.click(fn=lambda s: navigate_pdf_page("prev", s), inputs=[pdf_state], outputs=[pdf_preview_img, pdf_state, page_info])
+    next_page_btn.click(fn=lambda s: navigate_pdf_page("next", s), inputs=[pdf_state], outputs=[pdf_preview_img, pdf_state, page_info])
+
+    # Click handlers for NEW tabs
     obj_det_submit.click(
-        fn=
-
-
-
-        markdown_output: gr.update(visible=False)
-        },
-        outputs=[annotated_image_output, raw_detection_output, output, markdown_output]
-    ).then(
-        fn=generate_object_detection,
-        inputs=[obj_det_upload, obj_det_query, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-        outputs=[annotated_image_output, raw_detection_output]
+        fn=generate_detection,
+        inputs=[obj_det_image_upload, obj_det_query, gr.Textbox(value="Object Detection", visible=False),
+                max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        outputs=[obj_det_output_image, obj_det_output_text]
     )
-
     point_det_submit.click(
-
-
-
-
-        markdown_output: gr.update(visible=False)
-        },
-        outputs=[annotated_image_output, raw_detection_output, output, markdown_output]
-    ).then(
-        fn=generate_point_detection,
-        inputs=[point_det_upload, point_det_query, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-        outputs=[annotated_image_output, raw_detection_output]
+        fn=generate_detection,
+        inputs=[point_det_image_upload, point_det_query, gr.Textbox(value="Point Detection", visible=False),
+                max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        outputs=[point_det_output_image, point_det_output_text]
    )
 
-    pdf_upload.change(fn=load_and_preview_pdf, inputs=[pdf_upload], outputs=[pdf_preview_img, pdf_state, page_info])
-    prev_page_btn.click(fn=lambda s: navigate_pdf_page("prev", s), inputs=[pdf_state], outputs=[pdf_preview_img, pdf_state, page_info])
-    next_page_btn.click(fn=lambda s: navigate_pdf_page("next", s), inputs=[pdf_state], outputs=[pdf_preview_img, pdf_state, page_info])
 
 if __name__ == "__main__":
     demo.queue(max_size=50).launch(mcp_server=True, ssr_mode=False, show_error=True)