import os
import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from huggingface_hub import hf_hub_download
from colpali_engine.models import ColModernVBert, ColModernVBertProcessor
from colpali_engine.utils.torch_utils import get_torch_device
from datasets import load_dataset
import multiprocessing as mp
from functools import partial
import tqdm
import matplotlib.pyplot as plt
import base64
from io import BytesIO
import numpy as np
MODEL_ID = "ModernVBERT/colmodernvbert"
device = get_torch_device("auto")
processor = ColModernVBertProcessor.from_pretrained(MODEL_ID)
model = ColModernVBert.from_pretrained(
MODEL_ID,
torch_dtype=torch.float32,
trust_remote_code=True,
)
model.to(device)
model.eval()
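# Optional (untested sketch): on a CUDA device the model can usually be loaded in bfloat16
# to roughly halve memory use; float32 above is the safe default for CPU-only Spaces.
# model = ColModernVBert.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device).eval()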
INDEX_IMAGES = []
INDEX_EMB = None
TARGET_SIZE = (512, 512)
NUM_WORKERS = max(1, mp.cpu_count() // 2)  # Use half the CPU cores (at least one) to avoid contention
def _ensure_size(img: Image.Image) -> Image.Image:
if img.size != TARGET_SIZE:
return img.resize(TARGET_SIZE, Image.BICUBIC)
return img
def load_sample_images():
paths = [
hf_hub_download("HuggingFaceTB/SmolVLM", "example_images/rococo.jpg", repo_type="space"),
hf_hub_download("HuggingFaceTB/SmolVLM", "example_images/astronaut.png", repo_type="space"),
hf_hub_download("HuggingFaceTB/SmolVLM", "example_images/cat.png", repo_type="space"),
]
return [_ensure_size(Image.open(p).convert("RGB")) for p in paths]
def build_index(images):
global INDEX_IMAGES, INDEX_EMB
processed = [_ensure_size(img.convert("RGB")) for img in images]
INDEX_IMAGES = processed
with torch.inference_mode():
inputs = processor.process_images(processed)
inputs.to(device)
emb = model(**inputs)
INDEX_EMB = torch.nn.functional.normalize(emb, dim=-1)
return f"Indexed {len(processed)} images (resized to {TARGET_SIZE[0]}x{TARGET_SIZE[1]})"
def ensure_index():
if not INDEX_IMAGES:
# Auto-load 1000 images from ImageNet-1K dataset
print("Auto-loading 1000 images from ImageNet-1K dataset (this may take a few minutes)...")
builder_status = build_index_from_dataset("imagenet-1k", "validation", "image", 1000, 64)
print(f"Auto-indexing completed: {builder_status}")
return builder_status
def search(query, top_k=3):
ensure_index()
with torch.inference_mode():
q_inputs = processor.process_texts([query])
q_inputs.to(device)
q_emb = model(**q_inputs)
q_emb = torch.nn.functional.normalize(q_emb, dim=-1)
sims = (q_emb @ INDEX_EMB.T).squeeze(0)
vals, idxs = torch.topk(sims, k=min(top_k, len(INDEX_IMAGES)))
results = [(INDEX_IMAGES[i], f"score={vals[j].item():.4f}") for j, i in enumerate(idxs.tolist())]
return results
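# NOTE (assumption, not verified here): "Col*" models in colpali_engine are typically
# ColBERT-style late-interaction retrievers whose outputs are per-token embeddings of
# shape (batch, num_tokens, dim). If that holds for ColModernVBert, the single-vector
# dot product in search() above should be replaced by MaxSim scoring. Minimal sketch,
# assuming INDEX_EMB is a padded (N, Td, D) tensor (padded tokens would also need masking):
def maxsim_scores(q_emb: torch.Tensor, doc_emb: torch.Tensor) -> torch.Tensor:
    """MaxSim scores between a (1, Tq, D) query and (N, Td, D) documents -> (N,)."""
    sim = torch.matmul(q_emb, doc_emb.transpose(1, 2))  # (N, Tq, Td) token-pair similarities
    return sim.max(dim=-1).values.sum(dim=-1)  # best document token per query token, summed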
def upload_and_build(files):
if not files:
return "No files uploaded"
    images = [_ensure_size(Image.open(getattr(f, "name", f)).convert("RGB")) for f in files]  # accept file objects or path strings
return build_index(images)
def visualize_attention(text_embed, img_embeds, attention_mask=None):
"""Visualize attention between text and image embeddings"""
# Normalize embeddings
text_norm = F.normalize(text_embed, dim=-1)
img_norm = F.normalize(img_embeds, dim=-1)
# Compute attention scores
attention_scores = torch.matmul(text_norm, img_norm.transpose(-2, -1))
# Create attention heatmap
    scores = np.atleast_2d(attention_scores.squeeze().detach().cpu().numpy())  # imshow needs a 2-D array
fig, ax = plt.subplots(figsize=(10, 6))
    im = ax.imshow(scores, cmap='YlOrRd', aspect='auto')
ax.set_title('Text-Image Attention Map')
ax.set_xlabel('Image Embeddings')
ax.set_ylabel('Text Embeddings')
# Add colorbar
plt.colorbar(im, ax=ax)
plt.tight_layout()
    # Return a PIL image so it can be displayed directly by gr.Image(type="pil")
    buf = BytesIO()
    fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    buf.seek(0)
    heatmap = Image.open(buf).convert("RGB")
    plt.close(fig)
    return heatmap
def test_text_image_alignment(text_inputs, image_files, comparison_text=""):
"""Test alignment between uploaded text and images with real-time comparison"""
    if not image_files or len(image_files) < 2:
        return "❌ At least 2 images are required for comparison - upload 2+ images to compare", None, None
    if not text_inputs.strip():
        return "❌ Text input is required - enter text to test alignment", None, None
try:
# Process uploaded images
images = []
for f in image_files:
            img = Image.open(getattr(f, "name", f)).convert("RGB")
img = _ensure_size(img)
images.append(img)
with torch.inference_mode():
# Text embedding
text_processed = processor.process_texts([text_inputs])
text_processed.to(device)
text_embed = model(**text_processed)
text_embed = F.normalize(text_embed, dim=-1)
# Image embeddings
img_processed = processor.process_images(images)
img_processed.to(device)
img_embeds = model(**img_processed)
img_embeds = F.normalize(img_embeds, dim=-1)
# Compute similarities
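            # NOTE (assumption): this treats each model output as one pooled vector per input;
            # if ColModernVBert returns per-token (multi-vector) embeddings, pool or apply
            # MaxSim (see maxsim_scores above) before comparing.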
similarities = F.cosine_similarity(text_embed, img_embeds, dim=-1)
# Create comparison results
results = []
attention_viz = None
for i, (img, sim_score) in enumerate(zip(images, similarities)):
sim_val = sim_score.item()
caption = f"Similarity: {sim_val:.4f}"
# Score interpretation
            if sim_val > 0.7:
                interpretation = "🟢 Strong match"
            elif sim_val > 0.4:
                interpretation = "🟡 Moderate match"
            else:
                interpretation = "🔴 Weak match"
results.append((img, f"{caption} - {interpretation}"))
# Generate attention visualization
if len(results) >= 2:
attention_viz = visualize_attention(text_embed, img_embeds)
# Detailed analysis
analysis = f"""
**Real-time Testing Results:**
πŸ“ **Query Text:** "{text_inputs}"
πŸ–ΌοΈ **Images Tested:** {len(images)}
**Similarity Scores:**
"""
for i, sim_val in enumerate(similarities):
analysis += f"- Image {i+1}: {sim_val:.4f}\n"
analysis += f"""
**Best Match:** Image #{torch.argmax(similarities).item() + 1} (score: {similarities.max():.4f})
**Average Score:** {similarities.mean():.4f}
**Score Range:** {similarities.min():.4f} - {similarities.max():.4f}
**What this test demonstrates:**
✅ Text encoding: the query is embedded by the model
✅ Image encoding: each uploaded image is embedded by the model
✅ Cross-modal alignment: similarities are computed in a shared embedding space
✅ Heatmap: query-image embedding similarities (not internal attention weights)
"""
return analysis, results, attention_viz
except Exception as e:
return f"❌ Error during testing: {str(e)}", None, None
def _preprocess_image_worker(args):
"""Worker function for preprocessing images in parallel"""
row_data = args
if isinstance(row_data, tuple):
row, image_col, index = row_data
else:
# Handle direct image data
row, image_col = args
index = 0
if image_col not in row or row[image_col] is None:
return None, index
img = row[image_col]
if hasattr(img, "convert"):
img = img.convert("RGB")
img = _ensure_size(img)
return img, index
def build_index_from_dataset(repo_id: str, split: str = "train", image_col: str = "image", limit: int = 500, batch_size: int = 64):
global INDEX_IMAGES, INDEX_EMB
ds = load_dataset(repo_id, split=split, streaming=True)
# Step 1: Collect images in parallel
print(f"Loading and preprocessing {limit} images using {NUM_WORKERS} workers...")
image_data = []
count = 0
# Collect raw data first
for row in ds:
if image_col not in row or row[image_col] is None:
continue
image_data.append((row, image_col, count))
count += 1
if len(image_data) >= limit:
break
# Preprocess images in parallel
with mp.Pool(NUM_WORKERS) as pool:
results = list(tqdm.tqdm(
pool.imap(_preprocess_image_worker, image_data),
total=len(image_data),
desc="Preprocessing images"
))
# Filter out None results and sort by index
valid_results = [(img, idx) for img, idx in results if img is not None]
valid_results.sort(key=lambda x: x[1]) # Sort by original index
images = [img for img, _ in valid_results]
print(f"Successfully preprocessed {len(images)} images")
# Step 2: Embed images in batches (GPU intensive, keep single-threaded)
print("Computing embeddings...")
all_emb = []
with torch.inference_mode():
for i in tqdm.tqdm(range(0, len(images), batch_size), desc="Computing embeddings"):
batch = images[i:i+batch_size]
if not batch:
continue
inputs = processor.process_images(batch)
inputs.to(device)
emb = model(**inputs)
all_emb.append(torch.nn.functional.normalize(emb, dim=-1).to("cpu"))
INDEX_IMAGES = images
INDEX_EMB = torch.cat(all_emb, dim=0).to(device)
return f"Indexed {len(images)} images from {repo_id}:{split} (resized to {TARGET_SIZE[0]}x{TARGET_SIZE[1]}) - Used {NUM_WORKERS} workers"
with gr.Blocks(theme='default') as demo:
with gr.Tabs():
# Tab 1: Image Search
with gr.Tab("πŸ–ΌοΈ Image Search"):
gr.Markdown("# ColModernVBert Image Search")
gr.Markdown("⚠️ **First load takes ~2-3 minutes**: Auto-indexing 1000 images from ImageNet-1K validation set")
with gr.Row():
with gr.Column():
query = gr.Textbox(label="Text query", value="a baroque painting")
topk = gr.Slider(1, 8, value=3, step=1, label="Top-K")
btn = gr.Button("Search")
out = gr.Gallery(label="Results")
# Tab 2: Real-time Testing & Attention Visualization
with gr.Tab("πŸ§ͺ Model Testing"):
gr.Markdown("# Real-time Text-Image Alignment Testing")
gr.Markdown("Upload **minimum 2 images** and test with text queries to analyze model behavior")
with gr.Row():
with gr.Column():
test_text = gr.Textbox(
label="Test Query Text",
placeholder="Enter text like 'red car', 'dog playing', 'modern architecture'",
value="red sports car"
)
test_images = gr.File(
file_count="multiple",
file_types=["image"],
label="Upload Images (Min 2 required)"
)
test_btn = gr.Button("🧠 Test Model Alignment", variant="primary")
with gr.Column():
attention_viz = gr.Image(label="Attention Heatmap", type="pil")
with gr.Row():
                test_results = gr.Gallery(label="Image Similarity Results (2+ images)", columns=2)
test_analysis = gr.Markdown(label="Detailed Analysis")
test_btn.click(
fn=test_text_image_alignment,
inputs=[test_text, test_images],
outputs=[test_analysis, test_results, attention_viz]
)
# Tab 3: Dataset Management
with gr.Tab("πŸ“š Dataset Management"):
gr.Markdown("# Manage Image Index")
with gr.Row():
with gr.Column():
up = gr.File(file_count="multiple", type="filepath", label="Upload images to index")
status = gr.Textbox(label="Index status", interactive=False)
build = gr.Button("Build Index")
with gr.Accordion("Load from HF dataset", open=True):
repo = gr.Textbox(label="Dataset repo_id", value="imagenet-1k")
split = gr.Textbox(label="Split", value="validation")
img_col = gr.Textbox(label="Image column", value="image")
lim = gr.Number(label="Max images", value=1000, precision=0)
bsize = gr.Number(label="Batch size", value=64, precision=0)
build_ds = gr.Button("Build Index from Dataset")
status_ds = gr.Textbox(label="Index status", interactive=False)
# Event handlers
btn.click(fn=search, inputs=[query, topk], outputs=out)
build.click(fn=upload_and_build, inputs=[up], outputs=status)
build_ds.click(lambda r,s,c,l,b: build_index_from_dataset(r, s, c, int(l), int(b)), inputs=[repo, split, img_col, lim, bsize], outputs=status_ds)
if __name__ == "__main__":
    # Build the index before launching the UI (this call blocks until indexing finishes);
    # search() also calls ensure_index() as a fallback if the index is still empty.
    status_msg = ensure_index()
demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))