# app.py — UI-TARS demo (OSS disabled)
import base64
import json
import ast
import os
import re
import io
import math
from datetime import datetime

import gradio as gr
from PIL import ImageDraw
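# Note: `openai` is imported lazily below, only when OPENAI_API_KEY is set,
# so the demo still starts when that package is missing.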

# =========================
# OpenAI client (optional)
# =========================
# If OPENAI_API_KEY is set we will use OpenAI for parsing the model output text.
# If ENDPOINT_URL is set, we'll point the OpenAI client at that base URL (advanced use).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ENDPOINT_URL = os.getenv("ENDPOINT_URL")  # optional
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")  # safe default instead of "tgi"

client = None
if OPENAI_API_KEY:
    try:
        from openai import OpenAI

        if ENDPOINT_URL:
            client = OpenAI(api_key=OPENAI_API_KEY, base_url=ENDPOINT_URL)
        else:
            client = OpenAI(api_key=OPENAI_API_KEY)
        print("✅ OpenAI client initialized.")
    except Exception as e:
        print(f"⚠️ OpenAI client not available: {e}")
else:
    print("ℹ️ OPENAI_API_KEY not set. Running without OpenAI parsing.")

# =========================
# UI-TARS prompt
# =========================
DESCRIPTION = "[UI-TARS](https://github.com/bytedance/UI-TARS)"

prompt = (
    "Output only the coordinate of one box in your response. "
    "Return a tuple like (x,y) with values in 0..1000 for x and y. "
    "Do not include any extra text. "
)
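# Expected completion from a well-behaved model, e.g. "(512, 284)": relative
# coordinates on a 0..1000 grid, independent of the actual image resolution.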

# =========================
# OSS (Aliyun) — DISABLED
# =========================
# The original demo used Aliyun OSS (oss2) to upload images/metadata.
# We disable it fully so no ENV like BUCKET / ENDPOINT is required.
bucket = None
print("⚠️ OSS integration disabled: skipping Aliyun storage.")

def draw_point_area(image, point):
    """Draw a red point+circle at a (0..1000, 0..1000) coordinate on the given PIL image."""
    if not point:
        return image
    radius = min(image.width, image.height) // 15
    x = round(point[0] / 1000 * image.width)
    y = round(point[1] / 1000 * image.height)
    drawer = ImageDraw.Draw(image)
    drawer.ellipse((x - radius, y - radius, x + radius, y + radius), outline="red", width=2)
    drawer.ellipse((x - 2, y - 2, x + 2, y + 2), fill="red")
    return image
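# Example: on a 1920x1080 screenshot, point=(500, 500) is drawn at pixel
# (960, 540), with a circle radius of min(1920, 1080) // 15 = 72 px.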

def resize_image(image):
    """Rescale the screenshot toward a fixed pixel budget: very large captures are
    downscaled aggressively, everything else is normalized to a smaller budget
    (small images are scaled up)."""
    max_pixels = 6000 * 28 * 28
    if image.width * image.height > max_pixels:
        max_pixels = 2700 * 28 * 28
    else:
        max_pixels = 1340 * 28 * 28
    resize_factor = math.sqrt(max_pixels / (image.width * image.height))
    width, height = int(image.width * resize_factor), int(image.height * resize_factor)
    return image.resize((width, height))
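# Worked example: a 1920x1080 capture has 2,073,600 px, under the
# 6000*28*28 = 4,704,000 threshold, so the budget is 1340*28*28 = 1,050,560 px.
# resize_factor = sqrt(1050560 / 2073600) ≈ 0.712, giving roughly 1366x768.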

def upload_images(session_id, image, result_image, query):
    """No-op when OSS is disabled. Keeps the API stable."""
    if bucket is None:
        print("↪️ Skipped OSS upload (no bucket configured).")
        return
    img_path = f"{session_id}.png"
    result_img_path = f"{session_id}-draw.png"
    metadata = dict(
        query=query,
        resize_image=img_path,
        result_image=result_img_path,
        session_id=session_id,
    )
    img_bytes = io.BytesIO()
    image.save(img_bytes, format="png")
    bucket.put_object(img_path, img_bytes.getvalue())
    rst_img_bytes = io.BytesIO()
    result_image.save(rst_img_bytes, format="png")
    bucket.put_object(result_img_path, rst_img_bytes.getvalue())
    bucket.put_object(f"{session_id}.json", json.dumps(metadata).encode("utf-8"))
    print("✅ Uploaded input image, result image, and metadata to OSS.")

def run_ui(image, query, session_id, is_example_image):
    """Main inference path: builds the message, asks the model for (x,y), draws, returns results."""
    click_xy = None
    images_during_iterations = []
    width, height = image.width, image.height  # original size, for absolute-pixel output

    # Resize for throughput + encode
    image = resize_image(image)
    buf = io.BytesIO()
    image.save(buf, format="png")
    base64_image = base64.standard_b64encode(buf.getvalue()).decode("utf-8")

    # Prepare the prompt for an LLM that returns '(x,y)'
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}},
                {"type": "text", "text": prompt + query},
            ],
        }
    ]

    # If the OpenAI client is present, ask it for coordinates. Otherwise fall back to a safe default.
    output_text = ""
    if client is not None:
        try:
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=1.0,  # sampling params kept from the original demo
                top_p=0.7,
                max_tokens=128,
                frequency_penalty=1,
                stream=False,
            )
            output_text = resp.choices[0].message.content or ""
        except Exception as e:
            output_text = ""
            print(f"⚠️ OpenAI call failed: {e}")

    # Extract "(x,y)" from the text, e.g. "(512, 284)" -> "512, 284"
    pattern = r"\((\d+,\s*\d+)\)"
    match = re.search(pattern, output_text)
    if match:
        coordinates = match.group(1)
        try:
            click_xy = ast.literal_eval(coordinates)  # (x, y) on the 0..1000 scale
        except Exception:
            click_xy = None

    # If we still don't have coordinates, fall back to the image center
    if click_xy is None:
        click_xy = (500, 500)

    # Draw the result and convert to absolute pixel coords for display
    result_image = draw_point_area(image.copy(), click_xy)
    images_during_iterations.append(result_image)
    abs_xy = (round(click_xy[0] / 1000 * width), round(click_xy[1] / 1000 * height))

    # Upload artifacts only for real (non-example) inputs
    if str(is_example_image) == "False":
        upload_images(session_id, image, result_image, query)

    return images_during_iterations, str(abs_xy)
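# Quick sanity check without launching the UI (a sketch; assumes one of the
# bundled example screenshots exists on disk):
#   from PIL import Image
#   imgs, coords = run_ui(Image.open("./examples/amazon.jpg"), "Search bar", "debug-session", "True")
#   print(coords)  # absolute pixels on the original image, e.g. "(960, 120)"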

def update_vote(vote_type, image, click_image, prompt_text, is_example):
    """Simple feedback hook (no external upload while OSS is disabled)."""
    if vote_type == "upvote":
        return "Everything good"
    if is_example == "True":
        return "Do nothing for example"
    # The example gallery passes file paths rather than PIL images; with OSS
    # disabled there is nothing to persist, so we only acknowledge the vote.
    return "Thank you for your feedback!"

# Demo examples
examples = [
    ["./examples/solitaire.png", "Play the solitaire collection", True],
    ["./examples/weather_ui.png", "Open map", True],
    ["./examples/football_live.png", "click team 1 win", True],
    ["./examples/windows_panel.png", "switch to documents", True],
    ["./examples/paint_3d.png", "rotate left", True],
    ["./examples/finder.png", "view files from airdrop", True],
    ["./examples/amazon.jpg", "Search bar at the top of the page", True],
    ["./examples/semantic.jpg", "Home", True],
    ["./examples/accweather.jpg", "Select May", True],
    ["./examples/arxiv.jpg", "Home", True],
    ["./examples/health.jpg", "text labeled by 2023/11/26", True],
    ["./examples/ios_setting.png", "Turn off Do not disturb.", True],
]

title_markdown = """
# UI-TARS: Pioneering Automated GUI Interaction with Native Agents
[[🤗Model](https://huggingface.co/bytedance-research/UI-TARS-7B-SFT)] [[⌨️Code](https://github.com/bytedance/UI-TARS)] [[📑Paper](https://github.com/bytedance/UI-TARS/blob/main/UI_TARS_paper.pdf)] [🏄[Midscene (Browser Automation)](https://github.com/web-infra-dev/Midscene)] [🫨[Discord](https://discord.gg/txAE43ps)]
"""

tos_markdown = """
### Terms of use
This demo is governed by the original license of UI-TARS. We strongly advise users not to knowingly generate, or allow others to knowingly generate, harmful content, including hate speech, violence, pornography, and deception.
"""

learn_more_markdown = """
### License
Apache License 2.0
"""

code_adapt_markdown = """
### Acknowledgments
The app code is modified from [ShowUI](https://huggingface.co/spaces/showlab/ShowUI).
"""

block_css = """
#buttons button { min-width: min(120px, 100%); }
#chatbot img {
    max-width: 80%;
    max-height: 80vh;
    width: auto;
    height: auto;
    object-fit: contain;
}
"""

def build_demo():
    with gr.Blocks(title="UI-TARS Demo", theme=gr.themes.Default(), css=block_css) as demo:
        state_session_id = gr.State(value=None)
        gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                imagebox = gr.Image(type="pil", label="Input Screenshot")
                textbox = gr.Textbox(
                    show_label=True,
                    placeholder="Enter an instruction and press Submit",
                    label="Instruction",
                )
                submit_btn = gr.Button(value="Submit", variant="primary")
            with gr.Column(scale=6):
                output_gallery = gr.Gallery(label="Output with click", object_fit="contain", preview=True)
                gr.HTML(
                    """
                    <p><strong>Notice:</strong> The <span style="color: red;">red point</span> with a circle on the output image represents the predicted coordinates for a click.</p>
                    """
                )
                with gr.Row():
                    output_coords = gr.Textbox(label="Final Coordinates")
                    image_size = gr.Textbox(label="Image Size")
                gr.HTML("<p><strong>Expected result or not? Help us improve! ⬇️</strong></p>")
                with gr.Row(elem_id="action-buttons", equal_height=True):
                    upvote_btn = gr.Button(value="👍 Looks good!", variant="secondary")
                    downvote_btn = gr.Button(value="👎 Wrong coordinates!", variant="secondary")
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=True)
            with gr.Column(scale=3):
                gr.Examples(
                    examples=[[e[0], e[1]] for e in examples],
                    inputs=[imagebox, textbox],
                    outputs=[textbox],
                    examples_per_page=3,
                )
                is_example_dropdown = gr.Dropdown(
                    choices=["True", "False"],
                    value="False",
                    visible=False,
                    label="Is Example Image",
                )

        def set_is_example(query):
            # Flag the submission as an example when the instruction matches a bundled example query.
            for _, example_query, is_example in examples:
                if query.strip() == example_query.strip():
                    return str(is_example)
            return "False"

        textbox.change(set_is_example, inputs=[textbox], outputs=[is_example_dropdown])
        def on_submit(image, query, is_example_image):
            if image is None:
                raise gr.Error("No image provided. Please upload an image before submitting.")
            session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
            images_during_iterations, click_coords = run_ui(image, query, session_id, is_example_image)
            return images_during_iterations, click_coords, session_id, f"{image.width}x{image.height}"

        submit_btn.click(
            on_submit,
            [imagebox, textbox, is_example_dropdown],
            [output_gallery, output_coords, state_session_id, image_size],
        )
        clear_btn.click(
            lambda: (None, None, None, None, None, None),
            inputs=None,
            outputs=[imagebox, textbox, output_gallery, output_coords, state_session_id, image_size],
            queue=False,
        )
        upvote_btn.click(
            lambda image, click_image, prompt_text, is_example: update_vote(
                "upvote", image, click_image, prompt_text, is_example
            ),
            inputs=[imagebox, output_gallery, textbox, is_example_dropdown],
            outputs=[],
            queue=False,
        )
        downvote_btn.click(
            lambda image, click_image, prompt_text, is_example: update_vote(
                "downvote", image, click_image, prompt_text, is_example
            ),
            inputs=[imagebox, output_gallery, textbox, is_example_dropdown],
            outputs=[],
            queue=False,
        )

        gr.Markdown(tos_markdown)
        gr.Markdown(learn_more_markdown)
        gr.Markdown(code_adapt_markdown)

    return demo

if __name__ == "__main__":
    demo = build_demo()
    demo.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
    )
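# Local run (a sketch; package names assumed, versions unpinned):
#   pip install gradio pillow openai
#   OPENAI_API_KEY="sk-..." python app.py
# Then open http://localhost:7860 in a browser.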