sharathmajjigi committed
Commit efd12df · 1 Parent(s): 7d18df7

Implement proper UI-TARS grounding model with Qwen2.5-VL architecture

Files changed (2)
  1. app.py +128 -24
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,45 +1,149 @@
import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
import torch
from PIL import Image
import io
import base64
import json
+ import numpy as np

- # Load the UI-TARS model (this will download ~7GB on first run)
+ # UI-TARS is a Qwen2.5-VL model - use the correct model class
model_name = "ByteDance-Seed/UI-TARS-1.5-7B"
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name)
+
+ def load_model():
+     """Load UI-TARS model with proper configuration"""
+     try:
+         # UI-TARS requires specific handling for the Qwen2.5-VL architecture
+         from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
+
+         # Load processor and model with proper configuration
+         processor = Qwen2_5_VLProcessor.from_pretrained(
+             model_name,
+             trust_remote_code=True
+         )
+
+         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+             model_name,
+             torch_dtype=torch.float16,  # Use half precision for memory efficiency
+             device_map="auto",  # Automatically handle device placement
+             trust_remote_code=True,
+             low_cpu_mem_usage=True
+         )
+
+         print("✅ UI-TARS model loaded successfully!")
+         return model, processor
+
+     except Exception as e:
+         print(f"❌ Error loading UI-TARS: {e}")
+         print("Falling back to alternative approach...")
+
+         try:
+             # Alternative: Use the Auto classes with trust_remote_code
+             processor = AutoProcessor.from_pretrained(
+                 model_name,
+                 trust_remote_code=True
+             )
+
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 torch_dtype=torch.float16,
+                 device_map="auto",
+                 trust_remote_code=True,
+                 low_cpu_mem_usage=True
+             )
+
+             print("✅ UI-TARS loaded with AutoModelForCausalLM")
+             return model, processor
+
+         except Exception as e2:
+             print(f"❌ Alternative approach failed: {e2}")
+             return None, None
+
+ # Load model at startup
+ print("🔄 Loading UI-TARS model...")
+ model, processor = load_model()

def process_grounding(image, prompt):
    """
    Process image with UI-TARS grounding model
-     This is a simplified implementation - you'll need to adapt it
    """
    try:
+         if model is None or processor is None:
+             return json.dumps({
+                 "error": "Model not loaded",
+                 "status": "failed"
+             }, indent=2)
+
        # Convert image to PIL if needed
        if isinstance(image, str):
-             # Handle base64 string
            image_data = base64.b64decode(image)
            image = Image.open(io.BytesIO(image_data))

-         # Here you would implement the actual UI-TARS grounding logic
-         # For now, returning a mock response
-         result = {
-             "elements": [
-                 {"type": "button", "x": 100, "y": 200, "text": "Click me"},
-                 {"type": "text_field", "x": 150, "y": 300, "text": "Input field"}
-             ],
-             "actions": [
-                 {"action": "click", "x": 100, "y": 200, "description": "Click button"},
-                 {"action": "type", "x": 150, "y": 300, "description": "Type in field"}
-             ]
-         }
-
-         return json.dumps(result, indent=2)
-
+         # Prepare prompt for UI-TARS
+         # UI-TARS expects specific formatting for grounding tasks
+         formatted_prompt = f"""<image>
+ Please analyze this screenshot and provide grounding information for the following task: {prompt}
+
+ Please identify UI elements and provide:
+ 1. Element locations (x, y coordinates)
+ 2. Element types (button, text field, etc.)
+ 3. Recommended actions (click, type, etc.)
+ 4. Confidence scores
+
+ Format your response as JSON with the following structure:
+ {{
+     "elements": [
+         {{"type": "button", "x": 100, "y": 200, "text": "Click me", "confidence": 0.9}}
+     ],
+     "actions": [
+         {{"action": "click", "x": 100, "y": 200, "description": "Click button"}}
+     ]
+ }}"""
+
+         # Prepare inputs for the model
+         inputs = processor(
+             text=formatted_prompt,
+             images=image,
+             return_tensors="pt"
+         )
+
+         # Move inputs to the same device as the model
+         device = next(model.parameters()).device
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         # Generate grounding results
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=512,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_p=0.9,
+                 repetition_penalty=1.1
+             )
+
+         # Decode outputs
+         result_text = processor.decode(outputs[0], skip_special_tokens=True)
+
+         # Extract the response part after the prompt
+         response_start = result_text.find('{')
+         if response_start != -1:
+             response_json = result_text[response_start:]
+             try:
+                 # Try to parse as JSON
+                 parsed_result = json.loads(response_json)
+                 return json.dumps(parsed_result, indent=2)
+             except json.JSONDecodeError:
+                 # If JSON parsing fails, return the raw text
+                 return f"Raw Response:\n{result_text}\n\nNote: Response could not be parsed as JSON"
+         else:
+             return f"Model Response:\n{result_text}"
+
    except Exception as e:
-         return f"Error processing image: {str(e)}"
+         return json.dumps({
+             "error": f"Error processing image: {str(e)}",
+             "status": "failed"
+         }, indent=2)

# Create Gradio interface
iface = gr.Interface(
@@ -48,9 +152,9 @@ iface = gr.Interface(
        gr.Image(type="pil", label="Upload Screenshot"),
        gr.Textbox(label="Prompt/Goal", placeholder="What do you want to do?")
    ],
-     outputs=gr.Textbox(label="Grounding Results", lines=10),
+     outputs=gr.Textbox(label="Grounding Results", lines=15),
    title="UI-TARS Grounding Model",
-     description="Upload a screenshot and describe your goal to get grounding results"
+     description="Upload a screenshot and describe your goal to get grounding results from UI-TARS"
)

iface.launch()
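
One caveat with the new process_grounding(): Qwen2.5-VL processors normally build the prompt via apply_chat_template rather than a hand-written <image> tag, and decoding outputs[0] in full echoes the prompt, so find('{') can match the JSON template inside the echoed prompt instead of the model's answer. A minimal sketch of the standard Qwen2.5-VL input/output handling, assuming the model and processor returned by load_model() above; the messages structure follows the transformers Qwen2.5-VL examples and instruction_text is a hypothetical stand-in for the prompt body:

    # Sketch: Qwen2.5-VL-style generation, assuming model/processor from load_model().
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},          # PIL image from process_grounding
                {"type": "text", "text": instruction_text},  # grounding instructions, no <image> tag
            ],
        }
    ]

    # Let the processor's chat template insert the proper vision tokens.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=512)

    # Decode only the newly generated tokens, not the echoed prompt,
    # so JSON extraction sees the model's answer alone.
    completion = processor.batch_decode(
        outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )[0]

The official Qwen2.5-VL examples additionally route images through qwen_vl_utils.process_vision_info; passing the PIL image directly is the simpler variant sketched here.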
requirements.txt CHANGED
@@ -1,4 +1,7 @@
transformers
torch
+ torchvision
+ accelerate
+ numpy
Pillow
gradio
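
requirements.txt still leaves transformers unpinned, while the Qwen2.5-VL classes imported in load_model() only exist in recent releases (Qwen2.5-VL support arrived around transformers 4.49.0, to the best of my knowledge). A hedged sketch of a startup check that fails loudly instead of silently hitting the fallback path, assuming that minimum version:

    # Sketch: fail fast if the installed transformers predates Qwen2.5-VL support.
    # Assumption: the Qwen2.5-VL classes appeared in transformers 4.49.0.
    from packaging import version  # already a transformers dependency

    import transformers

    MIN_TRANSFORMERS = "4.49.0"

    if version.parse(transformers.__version__) < version.parse(MIN_TRANSFORMERS):
        raise RuntimeError(
            f"transformers {transformers.__version__} lacks Qwen2.5-VL support; "
            f"pin transformers>={MIN_TRANSFORMERS} in requirements.txt"
        )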