Vladyslav Humennyy committed on
Commit
edb7715
·
1 Parent(s): b7b5970

Fix vision processing by using PIL Images instead of file paths

Browse files

Previously, the processor was receiving file paths instead of actual PIL Image objects, causing vision functionality to fail. This commit updates the image handling pipeline:

- Changed _ensure_image_object() to return PIL Images instead of paths
- Updated user() to store PIL Images with proper multi-modal format
- Added _prepare_processor_history() to format messages correctly for the processor
- Updated helper functions to handle new image format

Files changed (1) hide show
  1. app.py +73 -19
app.py CHANGED
@@ -53,25 +53,23 @@ def load_model():
53
  model, tokenizer, processor, device = load_model()
54
 
55
 
56
- def _ensure_image_path(image_data: Any) -> str | None:
57
- """Return a filesystem path for the provided image data."""
58
  if image_data is None:
59
  return None
60
 
61
- if isinstance(image_data, str) and os.path.exists(image_data):
62
- return image_data
63
-
64
  try:
65
  from PIL import Image
66
  except ImportError: # pragma: no cover - PIL is bundled with Gradio's image component
67
  return None
68
 
 
69
  if isinstance(image_data, Image.Image):
70
- fd, tmp_path = tempfile.mkstemp(suffix=".png")
71
- os.close(fd)
72
- image_format = image_data.format or "PNG"
73
- image_data.save(tmp_path, format=image_format)
74
- return tmp_path
75
 
76
  return None
77
 
@@ -83,13 +81,26 @@ def user(user_message, image_data, history: list):
83
  has_content = False
84
 
85
  stripped_message = user_message.strip()
86
- if stripped_message:
 
 
 
 
 
 
 
 
 
 
 
 
87
  updated_history.append({"role": "user", "content": stripped_message})
88
  has_content = True
89
-
90
- image_path = _ensure_image_path(image_data)
91
- if image_path is not None:
92
- updated_history.append({"role": "user", "content": {"path": image_path, "alt_text": "User uploaded image"}})
 
93
  has_content = True
94
 
95
  if not has_content:
@@ -112,7 +123,7 @@ def append_example_message(x: gr.SelectData, history):
112
  def _message_contains_image(message: dict[str, Any]) -> bool:
113
  content = message.get("content")
114
  if isinstance(content, dict):
115
- if "path" in content:
116
  return True
117
  if content.get("type") in {"image", "image_url"}:
118
  return True
@@ -131,6 +142,8 @@ def _content_to_text(content: Any) -> str:
131
  alt_text = content.get("alt_text")
132
  placeholder = alt_text or os.path.basename(content["path"]) or "image"
133
  return f"[image: {placeholder}]"
 
 
134
  if content.get("type") == "image_url":
135
  image_url = content.get("image_url")
136
  if isinstance(image_url, dict):
@@ -147,8 +160,7 @@ def _content_to_text(content: Any) -> str:
147
  if item_type == "text":
148
  text_parts.append(item.get("text", ""))
149
  elif item_type == "image":
150
- alt_text = item.get("alt_text")
151
- text_parts.append(f"[image: {alt_text}]" if alt_text else "[image]")
152
  elif item_type == "image_url":
153
  image_url = item.get("image_url")
154
  if isinstance(image_url, dict):
@@ -188,6 +200,47 @@ def _prepare_text_history(history: list[dict[str, Any]]) -> list[dict[str, str]]
188
  return text_history
189
 
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  @spaces.GPU
192
  def bot(
193
  history: list[dict[str, Any]]
@@ -223,8 +276,9 @@ def bot(
223
 
224
  if processor is not None and any(_message_contains_image(msg) for msg in history):
225
  try:
 
226
  model_inputs = processor(
227
- messages=history,
228
  return_tensors="pt",
229
  add_generation_prompt=True,
230
  ).to(model.device)
 
53
  model, tokenizer, processor, device = load_model()
54
 
55
 
56
+ def _ensure_image_object(image_data: Any) -> Any | None:
57
+ """Return a PIL Image object for the provided image data."""
58
  if image_data is None:
59
  return None
60
 
 
 
 
61
  try:
62
  from PIL import Image
63
  except ImportError: # pragma: no cover - PIL is bundled with Gradio's image component
64
  return None
65
 
66
+ # Already a PIL Image
67
  if isinstance(image_data, Image.Image):
68
+ return image_data
69
+
70
+ # Load from path
71
+ if isinstance(image_data, str) and os.path.exists(image_data):
72
+ return Image.open(image_data)
73
 
74
  return None
75
 
 
81
  has_content = False
82
 
83
  stripped_message = user_message.strip()
84
+ image_obj = _ensure_image_object(image_data)
85
+
86
+ # If we have both text and image, combine them in a single message
87
+ if stripped_message and image_obj is not None:
88
+ updated_history.append({
89
+ "role": "user",
90
+ "content": [
91
+ {"type": "text", "text": stripped_message},
92
+ {"type": "image", "image": image_obj}
93
+ ]
94
+ })
95
+ has_content = True
96
+ elif stripped_message:
97
  updated_history.append({"role": "user", "content": stripped_message})
98
  has_content = True
99
+ elif image_obj is not None:
100
+ updated_history.append({
101
+ "role": "user",
102
+ "content": [{"type": "image", "image": image_obj}]
103
+ })
104
  has_content = True
105
 
106
  if not has_content:
 
123
  def _message_contains_image(message: dict[str, Any]) -> bool:
124
  content = message.get("content")
125
  if isinstance(content, dict):
126
+ if "path" in content or "image" in content:
127
  return True
128
  if content.get("type") in {"image", "image_url"}:
129
  return True
 
142
  alt_text = content.get("alt_text")
143
  placeholder = alt_text or os.path.basename(content["path"]) or "image"
144
  return f"[image: {placeholder}]"
145
+ if "image" in content:
146
+ return "[image]"
147
  if content.get("type") == "image_url":
148
  image_url = content.get("image_url")
149
  if isinstance(image_url, dict):
 
160
  if item_type == "text":
161
  text_parts.append(item.get("text", ""))
162
  elif item_type == "image":
163
+ text_parts.append("[image]")
 
164
  elif item_type == "image_url":
165
  image_url = item.get("image_url")
166
  if isinstance(image_url, dict):
 
200
  return text_history
201
 
202
 
203
+ def _prepare_processor_history(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
204
+ """Prepare history for processor with proper image format."""
205
+ processor_history = []
206
+
207
+ for message in history:
208
+ role = message.get("role", "user")
209
+ content = message.get("content")
210
+
211
+ # Handle different content formats
212
+ if isinstance(content, str):
213
+ # Simple text message
214
+ processor_history.append({"role": role, "content": content})
215
+ elif isinstance(content, list):
216
+ # Multi-modal content (text + images)
217
+ formatted_content = []
218
+ for item in content:
219
+ if isinstance(item, dict):
220
+ item_type = item.get("type")
221
+ if item_type == "text":
222
+ formatted_content.append({"type": "text", "text": item.get("text", "")})
223
+ elif item_type == "image":
224
+ # Include the PIL Image directly
225
+ formatted_content.append({"type": "image", "image": item.get("image")})
226
+ if formatted_content:
227
+ processor_history.append({"role": role, "content": formatted_content})
228
+ elif isinstance(content, dict):
229
+ # Legacy format or single image
230
+ if "image" in content:
231
+ processor_history.append({
232
+ "role": role,
233
+ "content": [{"type": "image", "image": content["image"]}]
234
+ })
235
+ else:
236
+ # Try to extract text
237
+ text = _content_to_text(content)
238
+ if text:
239
+ processor_history.append({"role": role, "content": text})
240
+
241
+ return processor_history
242
+
243
+
244
  @spaces.GPU
245
  def bot(
246
  history: list[dict[str, Any]]
 
276
 
277
  if processor is not None and any(_message_contains_image(msg) for msg in history):
278
  try:
279
+ processor_history = _prepare_processor_history(history)
280
  model_inputs = processor(
281
+ messages=processor_history,
282
  return_tensors="pt",
283
  add_generation_prompt=True,
284
  ).to(model.device)