Ram07 committed
Commit b3a7580 · Parent: 86e6f10

Add ColModernVBert image search app
Files changed (3)
  1. README.md +38 -11
  2. app.py +189 -0
  3. requirements.txt +11 -0
README.md CHANGED
@@ -1,13 +1,40 @@
- ---
- title: Image Search Colmodernvbert
- emoji: 📊
- colorFrom: gray
- colorTo: red
- sdk: gradio
- sdk_version: 5.48.0
- app_file: app.py
- pinned: false
- license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
+ # Image Search with ColModernVBert
+
+ A multimodal image search demo using ColModernVBert for cross-modal retrieval between text queries and images.
+
+ ## Features
+
+ - **Multimodal Search**: Query images using natural language text
+ - **ImageNet-1K Dataset**: Searches through 1000 diverse validation images
+ - **Automatic Indexing**: Indexes 1000 images on startup
+ - **512x512 Optimization**: Images are resized for optimal model performance
+ - **Multiprocessing**: Fast parallel image preprocessing
+ - **Upload Support**: Upload custom images or switch datasets
+
+ ## Usage
+
+ Enter text queries like:
+ - `"dog"` → Various dog breeds
+ - `"sports car"` → Different car models
+ - `"musical instrument"` → Guitars, pianos, violins
+ - `"food"` → Fruits, vegetables, dishes
+ - `"nature"` → Trees, flowers, landscapes
+
+ Adjust the Top-K slider to control the number of results returned.
+
+ ## Technical Details
+
+ - **Model**: ColModernVBert (ModernVBERT/colmodernvbert)
+ - **Dataset**: ImageNet-1K validation set (1000 images)
+ - **Image Size**: 512x512 pixels
+ - **Scoring**: Cosine similarity between text and image embeddings
+
+ ## Performance
+
+ - **Indexing**: ~2-3 minutes for 1000 images
+ - **Search**: Near-instant results
+ - **Memory**: Optimized for Space hardware limits
+
  ---

+ Built with Gradio and ColModernVBert for demonstration purposes.
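
For reference, a search like the ones above can also be issued programmatically once the Space is running. Below is a minimal sketch using `gradio_client`; the Space id `Ram07/image-search-colmodernvbert` and the `/search` endpoint name are assumptions (Gradio derives default API names from the handler function), so check the Space's "Use via API" panel for the real values:

```python
# Hedged sketch: call the Space's search endpoint remotely.
# ASSUMPTIONS: the Space id and api_name below are hypothetical;
# verify both in the Space's "Use via API" panel.
from gradio_client import Client

client = Client("Ram07/image-search-colmodernvbert")  # hypothetical Space id
results = client.predict(
    "sports car",        # text query
    3,                   # Top-K
    api_name="/search",  # default name derived from the `search` handler
)
print(results)  # gallery entries captioned like "score=0.8312"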
app.py ADDED
@@ -0,0 +1,189 @@
+ import os
+ import multiprocessing as mp
+
+ import gradio as gr
+ import torch
+ import tqdm
+ from PIL import Image
+ from huggingface_hub import hf_hub_download
+ from colpali_engine.models import ColModernVBert, ColModernVBertProcessor
+ from colpali_engine.utils.torch_utils import get_torch_device
+ from datasets import load_dataset
+
+ MODEL_ID = "ModernVBERT/colmodernvbert"
+
+ device = get_torch_device("auto")
+ processor = ColModernVBertProcessor.from_pretrained(MODEL_ID)
+ model = ColModernVBert.from_pretrained(
+     MODEL_ID,
+     torch_dtype=torch.float32,
+     trust_remote_code=True,
+ )
+ model.to(device)
+ model.eval()
+
+ INDEX_IMAGES = []
+ INDEX_EMB = None
+ TARGET_SIZE = (512, 512)
+ NUM_WORKERS = max(1, mp.cpu_count() // 2)  # Use half the CPU cores (at least one) to avoid contention
+
+
+ def _ensure_size(img: Image.Image) -> Image.Image:
+     if img.size != TARGET_SIZE:
+         return img.resize(TARGET_SIZE, Image.BICUBIC)
+     return img
+
+
+ def load_sample_images():
+     paths = [
+         hf_hub_download("HuggingFaceTB/SmolVLM", "example_images/rococo.jpg", repo_type="space"),
+         hf_hub_download("HuggingFaceTB/SmolVLM", "example_images/astronaut.png", repo_type="space"),
+         hf_hub_download("HuggingFaceTB/SmolVLM", "example_images/cat.png", repo_type="space"),
+     ]
+     return [_ensure_size(Image.open(p).convert("RGB")) for p in paths]
+
+
+ def build_index(images):
+     global INDEX_IMAGES, INDEX_EMB
+     processed = [_ensure_size(img.convert("RGB")) for img in images]
+     INDEX_IMAGES = processed
+     with torch.inference_mode():
+         inputs = processor.process_images(processed).to(device)
+         emb = model(**inputs)
+         INDEX_EMB = torch.nn.functional.normalize(emb, dim=-1)
+     return f"Indexed {len(processed)} images (resized to {TARGET_SIZE[0]}x{TARGET_SIZE[1]})"
+
+
+ def ensure_index():
+     if not INDEX_IMAGES:
+         # Auto-load 1000 images from the ImageNet-1K validation set.
+         # NOTE: imagenet-1k is gated on the Hub; the Space needs a token with access.
+         print("Auto-loading 1000 images from ImageNet-1K (this may take a few minutes)...")
+         status = build_index_from_dataset("imagenet-1k", "validation", "image", 1000, 64)
+         print(f"Auto-indexing completed: {status}")
+         return status
+     return f"Index ready ({len(INDEX_IMAGES)} images)"
+
+
+ def search(query, top_k=3):
+     ensure_index()
+     with torch.inference_mode():
+         q_inputs = processor.process_texts([query]).to(device)
+         q_emb = model(**q_inputs)
+         q_emb = torch.nn.functional.normalize(q_emb, dim=-1)
+         # Cosine similarity over L2-normalized embeddings
+         # (assumes the model returns one pooled vector per input)
+         sims = (q_emb @ INDEX_EMB.T).squeeze(0)
+         vals, idxs = torch.topk(sims, k=min(top_k, len(INDEX_IMAGES)))
+     results = [(INDEX_IMAGES[i], f"score={vals[j].item():.4f}") for j, i in enumerate(idxs.tolist())]
+     return results
+
+
+ def upload_and_build(files):
+     if not files:
+         return "No files uploaded"
+     # gr.File(type="filepath") yields plain path strings
+     images = [_ensure_size(Image.open(f).convert("RGB")) for f in files]
+     return build_index(images)
+
+
+ def _preprocess_image_worker(args):
+     """Resize one dataset row's image to TARGET_SIZE; runs in a worker process."""
+     row, image_col, index = args
+     if image_col not in row or row[image_col] is None:
+         return None, index
+     img = row[image_col]
+     if hasattr(img, "convert"):
+         img = img.convert("RGB")
+     return _ensure_size(img), index
+
+
+ def build_index_from_dataset(repo_id: str, split: str = "train", image_col: str = "image", limit: int = 500, batch_size: int = 64):
+     global INDEX_IMAGES, INDEX_EMB
+     ds = load_dataset(repo_id, split=split, streaming=True)
+
+     # Step 1: collect up to `limit` rows from the streaming dataset
+     print(f"Loading and preprocessing {limit} images using {NUM_WORKERS} workers...")
+     image_data = []
+     count = 0
+     for row in ds:
+         if image_col not in row or row[image_col] is None:
+             continue
+         image_data.append((row, image_col, count))
+         count += 1
+         if len(image_data) >= limit:
+             break
+
+     # Step 2: preprocess (resize/convert) in parallel
+     with mp.Pool(NUM_WORKERS) as pool:
+         results = list(tqdm.tqdm(
+             pool.imap(_preprocess_image_worker, image_data),
+             total=len(image_data),
+             desc="Preprocessing images",
+         ))
+
+     # Drop failed rows and restore the original order
+     valid_results = sorted(((img, idx) for img, idx in results if img is not None), key=lambda x: x[1])
+     images = [img for img, _ in valid_results]
+     print(f"Successfully preprocessed {len(images)} images")
+
+     # Step 3: embed in batches (GPU-bound, kept single-process)
+     print("Computing embeddings...")
+     all_emb = []
+     with torch.inference_mode():
+         for i in tqdm.tqdm(range(0, len(images), batch_size), desc="Computing embeddings"):
+             batch = images[i:i + batch_size]
+             inputs = processor.process_images(batch).to(device)
+             emb = model(**inputs)
+             all_emb.append(torch.nn.functional.normalize(emb, dim=-1).to("cpu"))
+
+     INDEX_IMAGES = images
+     INDEX_EMB = torch.cat(all_emb, dim=0).to(device)
+     return f"Indexed {len(images)} images from {repo_id}:{split} (resized to {TARGET_SIZE[0]}x{TARGET_SIZE[1]}) - Used {NUM_WORKERS} workers"
+
+
+ with gr.Blocks(theme="default") as demo:
+     gr.Markdown("# ColModernVBert Image Search (Minimal Demo)")
+     gr.Markdown("⚠️ **First load takes ~2-3 minutes**: auto-indexing 1000 images from the ImageNet-1K validation set")
+     with gr.Row():
+         with gr.Column():
+             query = gr.Textbox(label="Text query", value="a baroque painting")
+             topk = gr.Slider(1, 8, value=3, step=1, label="Top-K")
+             btn = gr.Button("Search")
+             out = gr.Gallery(label="Results", columns=3, rows=1)
+         with gr.Column():
+             up = gr.File(file_count="multiple", type="filepath", label="Upload images to index")
+             status = gr.Textbox(label="Index status", interactive=False)
+             build = gr.Button("Build Index")
+             with gr.Accordion("Load from HF dataset (replace auto-loaded images)", open=True):
+                 repo = gr.Textbox(label="Dataset repo_id", value="imagenet-1k")
+                 split = gr.Textbox(label="Split", value="validation")
+                 img_col = gr.Textbox(label="Image column", value="image")
+                 lim = gr.Number(label="Max images", value=1000, precision=0)
+                 bsize = gr.Number(label="Batch size", value=64, precision=0)
+                 build_ds = gr.Button("Build Index from Dataset")
+                 status_ds = gr.Textbox(label="Index status", interactive=False)
+
+     btn.click(fn=search, inputs=[query, topk], outputs=out)
+     build.click(fn=upload_and_build, inputs=[up], outputs=status)
+     build_ds.click(
+         lambda r, s, c, l, b: build_index_from_dataset(r, s, c, int(l), int(b)),
+         inputs=[repo, split, img_col, lim, bsize],
+         outputs=status_ds,
+     )
+
+ if __name__ == "__main__":
+     # Build the index up front so the first search is fast; search() also
+     # calls ensure_index() as a safety net if the index is still empty
+     ensure_index()
+     demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
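A note on scoring: `search()` ranks with a dot product over L2-normalized embeddings, i.e. cosine similarity, which assumes the model emits one pooled vector per text or image. `Col*` models in `colpali_engine` are typically late-interaction (ColBERT-style) and return one embedding per token or patch; if that holds for ColModernVBert, scoring would go through the processor's multi-vector scorer instead. A minimal sketch, reusing `model`, `processor`, and `device` from app.py, and assuming `ColModernVBertProcessor` inherits `score_multi_vector` from the library's base retrieval processor:

```python
# Hedged sketch: late-interaction (MaxSim) scoring over multi-vector embeddings.
# ASSUMPTIONS: score_multi_vector is inherited from colpali_engine's base
# retrieval processor, and the model returns a (num_tokens, dim) matrix per input.
import torch

def search_multi_vector(query: str, image_embs: list, top_k: int = 3):
    with torch.inference_mode():
        q_inputs = processor.process_texts([query]).to(device)
        q_emb = model(**q_inputs)  # shape (1, num_query_tokens, dim)
    # MaxSim: for each query token take the best-matching doc token, then sum
    scores = processor.score_multi_vector(list(q_emb), image_embs)  # (1, N)
    vals, idxs = torch.topk(scores.squeeze(0), k=min(top_k, len(image_embs)))
    return list(zip(idxs.tolist(), vals.tolist()))
```

Under this scheme `image_embs` would be kept as a list of per-image tensors rather than the single concatenated `INDEX_EMB` matrix, since token counts can differ across batches.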
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ huggingface_hub>=0.35.3
+ torch>=2.2.0
+ torchvision>=0.17.0
+ transformers>=4.40.2
+ pillow>=10.3.0
+ accelerate>=0.29.0
+ gradio>=4.44.0
+ datasets>=2.20.0
+ tqdm>=4.60.0
+
+ # flash-attn>=2.0.0  # Optional: requires CUDA toolkit
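
Two notes on this dependency list. First, app.py imports `colpali_engine`, but no `colpali-engine` requirement appears above, so the Space would fail at import time unless the package is installed separately (the exact version floor that includes ColModernVBert support is not stated in the diff). Second, regarding the commented-out flash-attn line: a minimal sketch of how the model load in app.py could opt into FlashAttention 2 only when the package is present. `attn_implementation` is a standard `transformers.from_pretrained` argument, but whether ColModernVBert behaves identically under it is an assumption:

```python
# Hedged sketch: enable FlashAttention 2 only when flash-attn is installed.
# flash-attn kernels require half precision, so the dtype switches with it;
# app.py's float32 default is kept otherwise.
import importlib.util

import torch
from colpali_engine.models import ColModernVBert

use_fa2 = importlib.util.find_spec("flash_attn") is not None

model = ColModernVBert.from_pretrained(
    "ModernVBERT/colmodernvbert",
    torch_dtype=torch.bfloat16 if use_fa2 else torch.float32,
    trust_remote_code=True,
    attn_implementation="flash_attention_2" if use_fa2 else None,  # None = library default
)
```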