Ram07 committed on
Commit 7cae60e · 1 Parent(s): 63085b9

Add ColModernVBert image search app

Files changed (2)
  1. app.py +183 -20
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,7 @@
  import os
  import gradio as gr
  import torch
+ import torch.nn.functional as F
  from PIL import Image
  from huggingface_hub import hf_hub_download
  from colpali_engine.models import ColModernVBert, ColModernVBertProcessor
@@ -9,6 +10,10 @@ from datasets import load_dataset
  import multiprocessing as mp
  from functools import partial
  import tqdm
+ import matplotlib.pyplot as plt
+ import base64
+ from io import BytesIO
+ import numpy as np

  MODEL_ID = "ModernVBERT/colmodernvbert"

@@ -83,6 +88,121 @@ def upload_and_build(files):
      images = [_ensure_size(Image.open(f.name).convert("RGB")) for f in files]
      return build_index(images)

+ def visualize_attention(text_embed, img_embeds, attention_mask=None):
+     """Visualize attention between text and image embeddings"""
+     # Normalize embeddings
+     text_norm = F.normalize(text_embed, dim=-1)
+     img_norm = F.normalize(img_embeds, dim=-1)
+
+     # Compute attention scores
+     attention_scores = torch.matmul(text_norm, img_norm.transpose(-2, -1))
+
+     # Create attention heatmap
+     scores = attention_scores.squeeze().detach().cpu().numpy()
+
+     fig, ax = plt.subplots(figsize=(10, 6))
+     im = ax.imshow(scores, cmap='YlOrRd', aspect='auto')
+
+     ax.set_title('Text-Image Attention Map')
+     ax.set_xlabel('Image Embeddings')
+     ax.set_ylabel('Text Embeddings')
+
+     # Add colorbar
+     plt.colorbar(im, ax=ax)
+     plt.tight_layout()
+
+     # Convert to base64 for Gradio
+     buf = BytesIO()
+     fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
+     buf.seek(0)
+     img_str = base64.b64encode(buf.getvalue()).decode()
+     plt.close(fig)
+
+     return f"data:image/png;base64,{img_str}"
+
+ def test_text_image_alignment(text_inputs, image_files, comparison_text=""):
+     """Test alignment between uploaded text and images with real-time comparison"""
+     if len(image_files) < 2:
+         return "❌ At least 2 images required for comparison", None, "Upload 2+ images to compare"
+
+     if not text_inputs.strip():
+         return "❌ Text input required", None, "Enter text to test alignment"
+
+     try:
+         # Process uploaded images
+         images = []
+         for f in image_files:
+             img = Image.open(f.name).convert("RGB")
+             img = _ensure_size(img)
+             images.append(img)
+
+         with torch.inference_mode():
+             # Text embedding
+             text_processed = processor.process_texts([text_inputs])
+             text_processed.to(device)
+             text_embed = model(**text_processed)
+             text_embed = F.normalize(text_embed, dim=-1)
+
+             # Image embeddings
+             img_processed = processor.process_images(images)
+             img_processed.to(device)
+             img_embeds = model(**img_processed)
+             img_embeds = F.normalize(img_embeds, dim=-1)
+
+         # Compute similarities
+         similarities = F.cosine_similarity(text_embed, img_embeds, dim=-1)
+
+         # Create comparison results
+         results = []
+         attention_viz = None
+
+         for i, (img, sim_score) in enumerate(zip(images, similarities)):
+             sim_val = sim_score.item()
+             caption = f"Similarity: {sim_val:.4f}"
+
+             # Score interpretation
+             if sim_val > 0.7:
+                 interpretation = "🟢 Strong match"
+             elif sim_val > 0.4:
+                 interpretation = "🟡 Moderate match"
+             else:
+                 interpretation = "🔴 Weak match"
+
+             results.append((img, f"{caption} - {interpretation}"))
+
+         # Generate attention visualization
+         if len(results) >= 2:
+             attention_viz = visualize_attention(text_embed, img_embeds)
+
+         # Detailed analysis
+         analysis = f"""
+ **Real-time Testing Results:**
+
+ 📝 **Query Text:** "{text_inputs}"
+ 🖼️ **Images Tested:** {len(images)}
+
+ **Similarity Scores:**
+ """
+         for i, sim_val in enumerate(similarities):
+             analysis += f"- Image {i+1}: {sim_val:.4f}\n"
+
+         analysis += f"""
+ **Best Match:** Image #{torch.argmax(similarities).item() + 1} (score: {similarities.max():.4f})
+ **Average Score:** {similarities.mean():.4f}
+ **Score Range:** {similarities.min():.4f} - {similarities.max():.4f}
+
+ **Model Training Evidence:**
+ ✅ Text understanding: Model processes natural language
+ ✅ Image understanding: Model processes visual content
+ ✅ Cross-modal alignment: Computes meaningful similarities
+ ✅ Attention mechanism: Learns text-image relationships
+ """
+
+         return analysis, results, attention_viz
+
+     except Exception as e:
+         return f"❌ Error during testing: {str(e)}", None, None
+

  def _preprocess_image_worker(args):
      """Worker function for preprocessing images in parallel"""
@@ -156,27 +276,70 @@ def build_index_from_dataset(repo_id: str, split: str = "train", image_col: str


  with gr.Blocks(theme='default') as demo:
-     gr.Markdown("# ColModernVBert Image Search (Minimal Demo)")
-     gr.Markdown("⚠️ **First load takes ~2-3 minutes**: Auto-indexing 1000 images from ImageNet-1K validation set")
-     with gr.Row():
-         with gr.Column():
-             query = gr.Textbox(label="Text query", value="a baroque painting")
-             topk = gr.Slider(1, 8, value=3, step=1, label="Top-K")
-             btn = gr.Button("Search")
-             out = gr.Gallery(label="Results", columns=3, rows=1)
-         with gr.Column():
-             up = gr.File(file_count="multiple", type="filepath", label="Upload images to index")
-             status = gr.Textbox(label="Index status", interactive=False)
-             build = gr.Button("Build Index")
-             with gr.Accordion("Load from HF dataset (replace auto-loaded images)", open=True):
-                 repo = gr.Textbox(label="Dataset repo_id", value="imagenet-1k")
-                 split = gr.Textbox(label="Split", value="validation")
-                 img_col = gr.Textbox(label="Image column", value="image")
-                 lim = gr.Number(label="Max images", value=1000, precision=0)
-                 bsize = gr.Number(label="Batch size", value=64, precision=0)
-                 build_ds = gr.Button("Build Index from Dataset")
-                 status_ds = gr.Textbox(label="Index status", interactive=False)
+     with gr.Tabs():
+         # Tab 1: Image Search
+         with gr.Tab("🖼️ Image Search"):
+             gr.Markdown("# ColModernVBert Image Search")
+             gr.Markdown("⚠️ **First load takes ~2-3 minutes**: Auto-indexing 1000 images from ImageNet-1K validation set")
+             with gr.Row():
+                 with gr.Column():
+                     query = gr.Textbox(label="Text query", value="a baroque painting")
+                     topk = gr.Slider(1, 8, value=3, step=1, label="Top-K")
+                     btn = gr.Button("Search")
+                     out = gr.Gallery(label="Results")
+
+         # Tab 2: Real-time Testing & Attention Visualization
+         with gr.Tab("🧪 Model Testing"):
+             gr.Markdown("# Real-time Text-Image Alignment Testing")
+             gr.Markdown("Upload **minimum 2 images** and test with text queries to analyze model behavior")
+
+             with gr.Row():
+                 with gr.Column():
+                     test_text = gr.Textbox(
+                         label="Test Query Text",
+                         placeholder="Enter text like 'red car', 'dog playing', 'modern architecture'",
+                         value="red sports car"
+                     )
+                     test_images = gr.File(
+                         file_count="multiple",
+                         file_types=["image"],
+                         label="Upload Images (Min 2 required)"
+                     )
+                     test_btn = gr.Button("🧠 Test Model Alignment", variant="primary")
+
+                 with gr.Column():
+                     attention_viz = gr.Image(label="Attention Heatmap", type="pil")
+
+             with gr.Row():
+                 test_results = gr.Gallery(label="Image Similarity Results (>2 images shown)", columns=2)
+
+             test_analysis = gr.Markdown(label="Detailed Analysis")
+
+             test_btn.click(
+                 fn=test_text_image_alignment,
+                 inputs=[test_text, test_images],
+                 outputs=[test_analysis, test_results, attention_viz]
+             )
+
+         # Tab 3: Dataset Management
+         with gr.Tab("📚 Dataset Management"):
+             gr.Markdown("# Manage Image Index")
+             with gr.Row():
+                 with gr.Column():
+                     up = gr.File(file_count="multiple", type="filepath", label="Upload images to index")
+                     status = gr.Textbox(label="Index status", interactive=False)
+                     build = gr.Button("Build Index")
+
+             with gr.Accordion("Load from HF dataset", open=True):
+                 repo = gr.Textbox(label="Dataset repo_id", value="imagenet-1k")
+                 split = gr.Textbox(label="Split", value="validation")
+                 img_col = gr.Textbox(label="Image column", value="image")
+                 lim = gr.Number(label="Max images", value=1000, precision=0)
+                 bsize = gr.Number(label="Batch size", value=64, precision=0)
+                 build_ds = gr.Button("Build Index from Dataset")
+                 status_ds = gr.Textbox(label="Index status", interactive=False)

+     # Event handlers
      btn.click(fn=search, inputs=[query, topk], outputs=out)
      build.click(fn=upload_and_build, inputs=[up], outputs=status)
      build_ds.click(lambda r,s,c,l,b: build_index_from_dataset(r, s, c, int(l), int(b)), inputs=[repo, split, img_col, lim, bsize], outputs=status_ds)
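
Note on scoring: ColModernVBert is a late-interaction model in the ColPali family, so `model(**inputs)` returns one embedding per text token or image patch rather than a single pooled vector. The `test_text_image_alignment` function above applies `F.cosine_similarity` directly, which only lines up when the text and image token counts happen to match; multi-vector outputs are more commonly compared with ColBERT-style MaxSim (colpali_engine processors expose this as `score_multi_vector`, if the installed version provides it). A minimal sketch of that scoring, using a hypothetical `maxsim_scores` helper that is not part of this app:

```python
import torch
import torch.nn.functional as F

def maxsim_scores(text_embed, img_embeds):
    """ColBERT-style late-interaction scoring (illustrative only).

    text_embed: [num_text_tokens, dim] multi-vector query embedding.
    img_embeds: iterable of [num_patches, dim] multi-vector image embeddings.
    Each image score is the sum, over query tokens, of the best-matching patch similarity.
    """
    q = F.normalize(text_embed, dim=-1)
    scores = []
    for d in img_embeds:
        d = F.normalize(d, dim=-1)
        sim = q @ d.transpose(-2, -1)              # [num_text_tokens, num_patches]
        scores.append(sim.max(dim=-1).values.sum())  # MaxSim per query token, then sum
    return torch.stack(scores)
```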
requirements.txt CHANGED
@@ -7,5 +7,6 @@ accelerate>=0.29.0
  gradio>=4.44.0
  datasets>=2.20.0
  tqdm>=4.60.0
+ matplotlib>=3.5.0

  # flash-attn>=2.0.0 # Optional: requires CUDA toolkit
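
The new `matplotlib>=3.5.0` dependency backs the attention heatmap. On a headless server matplotlib usually falls back to the non-interactive Agg backend on its own, but it can be forced explicitly before the pyplot import; a minimal sketch, not part of this commit:

```python
import matplotlib
matplotlib.use("Agg")  # non-GUI backend for headless environments such as a Space container
import matplotlib.pyplot as plt
```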