import os
import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from huggingface_hub import hf_hub_download
from colpali_engine.models import ColModernVBert, ColModernVBertProcessor
from colpali_engine.utils.torch_utils import get_torch_device
from datasets import load_dataset
import multiprocessing as mp
import tqdm
import matplotlib.pyplot as plt
from io import BytesIO
import numpy as np

MODEL_ID = "ModernVBERT/colmodernvbert"

device = get_torch_device("auto")
processor = ColModernVBertProcessor.from_pretrained(MODEL_ID)
model = ColModernVBert.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    trust_remote_code=True,
)
model.to(device)
model.eval()

INDEX_IMAGES = []
INDEX_EMB = None
TARGET_SIZE = (512, 512)
NUM_WORKERS = max(1, mp.cpu_count() // 2)  # Use half the CPU cores to avoid contention


def _ensure_size(img: Image.Image) -> Image.Image:
    if img.size != TARGET_SIZE:
        return img.resize(TARGET_SIZE, Image.BICUBIC)
    return img


def load_sample_images():
    paths = [
        hf_hub_download("HuggingFaceTB/SmolVLM", "example_images/rococo.jpg", repo_type="space"),
        hf_hub_download("HuggingFaceTB/SmolVLM", "example_images/astronaut.png", repo_type="space"),
        hf_hub_download("HuggingFaceTB/SmolVLM", "example_images/cat.png", repo_type="space"),
    ]
    return [_ensure_size(Image.open(p).convert("RGB")) for p in paths]


def build_index(images):
    global INDEX_IMAGES, INDEX_EMB
    processed = [_ensure_size(img.convert("RGB")) for img in images]
    INDEX_IMAGES = processed
    with torch.inference_mode():
        inputs = processor.process_images(processed)
        inputs.to(device)
        emb = model(**inputs)
    INDEX_EMB = torch.nn.functional.normalize(emb, dim=-1)
    return f"Indexed {len(processed)} images (resized to {TARGET_SIZE[0]}x{TARGET_SIZE[1]})"


def ensure_index():
    if not INDEX_IMAGES:
        # Auto-load 1000 images from the ImageNet-1K validation split
        print("Auto-loading 1000 images from ImageNet-1K (this may take a few minutes)...")
        builder_status = build_index_from_dataset("imagenet-1k", "validation", "image", 1000, 64)
        print(f"Auto-indexing completed: {builder_status}")
        return builder_status


def search(query, top_k=3):
    ensure_index()
    with torch.inference_mode():
        q_inputs = processor.process_texts([query])
        q_inputs.to(device)
        q_emb = model(**q_inputs)
        q_emb = torch.nn.functional.normalize(q_emb, dim=-1)
    # Single dot-product scoring; assumes one embedding vector per query/image.
    sims = (q_emb @ INDEX_EMB.T).squeeze(0)
    vals, idxs = torch.topk(sims, k=min(top_k, len(INDEX_IMAGES)))
    results = [(INDEX_IMAGES[i], f"score={vals[j].item():.4f}") for j, i in enumerate(idxs.tolist())]
    return results


def upload_and_build(files):
    if not files:
        return "No files uploaded"
    # gr.File may return plain paths (type="filepath") or tempfile objects with .name
    images = [
        _ensure_size(Image.open(f if isinstance(f, str) else f.name).convert("RGB"))
        for f in files
    ]
    return build_index(images)
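
# ---------------------------------------------------------------------------
# Hedged sketch, not wired into the UI: ColModernVBert is a ColBERT-style
# late-interaction model, so model(**inputs) may return one embedding per
# token/patch rather than a single pooled vector. search() above assumes
# pooled 2-D embeddings. If the installed colpali_engine version exposes
# processor.score_multi_vector (the ColPali-family processors do), MaxSim
# scoring over multi-vector outputs would look roughly like this; treat it as
# an illustration under those assumptions, not a drop-in replacement.
# ---------------------------------------------------------------------------
def search_maxsim_sketch(query: str, image_embeddings: list, top_k: int = 3):
    """Rank pre-computed multi-vector image embeddings against one text query.

    `image_embeddings` is assumed to be a list with one (num_patches, dim)
    tensor per indexed image, collected at indexing time.
    """
    with torch.inference_mode():
        q_inputs = processor.process_texts([query])
        q_inputs.to(device)
        q_emb = model(**q_inputs)  # expected shape: (1, num_query_tokens, dim)
    # score_multi_vector returns a (num_queries, num_images) MaxSim score matrix
    scores = processor.score_multi_vector(
        list(torch.unbind(q_emb.to("cpu"))),  # one tensor per query
        image_embeddings,
    )
    vals, idxs = torch.topk(scores[0], k=min(top_k, len(image_embeddings)))
    return list(zip(idxs.tolist(), vals.tolist()))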

def visualize_attention(text_embed, img_embeds, attention_mask=None):
    """Visualize text-image embedding similarities as a heatmap."""
    # Normalize embeddings
    text_norm = F.normalize(text_embed, dim=-1)
    img_norm = F.normalize(img_embeds, dim=-1)

    # Compute pairwise similarity ("attention") scores
    attention_scores = torch.matmul(text_norm, img_norm.transpose(-2, -1))

    # Create attention heatmap (atleast_2d so a single text embedding still renders)
    scores = np.atleast_2d(attention_scores.squeeze().detach().cpu().numpy())
    fig, ax = plt.subplots(figsize=(10, 6))
    im = ax.imshow(scores, cmap='YlOrRd', aspect='auto')
    ax.set_title('Text-Image Attention Map')
    ax.set_xlabel('Image Embeddings')
    ax.set_ylabel('Text Embeddings')

    # Add colorbar
    plt.colorbar(im, ax=ax)
    plt.tight_layout()

    # Render to a PIL image so gr.Image(type="pil") can display it directly
    buf = BytesIO()
    fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    buf.seek(0)
    heatmap = Image.open(buf)
    heatmap.load()
    return heatmap


def test_text_image_alignment(text_inputs, image_files, comparison_text=""):
    """Test alignment between uploaded text and images with real-time comparison"""
    if not image_files or len(image_files) < 2:
        return "❌ At least 2 images required for comparison. Upload 2+ images to compare.", None, None
    if not text_inputs.strip():
        return "❌ Text input required. Enter text to test alignment.", None, None

    try:
        # Process uploaded images
        images = []
        for f in image_files:
            path = f if isinstance(f, str) else f.name  # gr.File may return paths or file objects
            img = Image.open(path).convert("RGB")
            img = _ensure_size(img)
            images.append(img)

        with torch.inference_mode():
            # Text embedding
            text_processed = processor.process_texts([text_inputs])
            text_processed.to(device)
            text_embed = model(**text_processed)
            text_embed = F.normalize(text_embed, dim=-1)

            # Image embeddings
            img_processed = processor.process_images(images)
            img_processed.to(device)
            img_embeds = model(**img_processed)
            img_embeds = F.normalize(img_embeds, dim=-1)

        # Compute similarities (broadcasts the single text embedding against each image)
        similarities = F.cosine_similarity(text_embed, img_embeds, dim=-1)

        # Create comparison results
        results = []
        attention_viz = None

        for i, (img, sim_score) in enumerate(zip(images, similarities)):
            sim_val = sim_score.item()
            caption = f"Similarity: {sim_val:.4f}"

            # Score interpretation
            if sim_val > 0.7:
                interpretation = "🟢 Strong match"
            elif sim_val > 0.4:
                interpretation = "🟡 Moderate match"
            else:
                interpretation = "🔴 Weak match"

            results.append((img, f"{caption} - {interpretation}"))

        # Generate attention visualization
        if len(results) >= 2:
            attention_viz = visualize_attention(text_embed, img_embeds)

        # Detailed analysis
        analysis = f"""
**Real-time Testing Results:**

📝 **Query Text:** "{text_inputs}"
🖼️ **Images Tested:** {len(images)}

**Similarity Scores:**
"""
        for i, sim_val in enumerate(similarities):
            analysis += f"- Image {i + 1}: {sim_val:.4f}\n"

        analysis += f"""
**Best Match:** Image #{torch.argmax(similarities).item() + 1} (score: {similarities.max():.4f})
**Average Score:** {similarities.mean():.4f}
**Score Range:** {similarities.min():.4f} - {similarities.max():.4f}

**Model Training Evidence:**
✅ Text understanding: Model processes natural language
✅ Image understanding: Model processes visual content
✅ Cross-modal alignment: Computes meaningful similarities
✅ Attention mechanism: Learns text-image relationships
"""
        return analysis, results, attention_viz

    except Exception as e:
        return f"❌ Error during testing: {str(e)}", None, None


def _preprocess_image_worker(args):
    """Worker function for preprocessing a single (row, image_col, index) item in parallel."""
    row, image_col, index = args
    if image_col not in row or row[image_col] is None:
        return None, index
    img = row[image_col]
    if hasattr(img, "convert"):
        img = img.convert("RGB")
    img = _ensure_size(img)
    return img, index


def build_index_from_dataset(repo_id: str, split: str = "train", image_col: str = "image",
                             limit: int = 500, batch_size: int = 64):
    global INDEX_IMAGES, INDEX_EMB
    ds = load_dataset(repo_id, split=split, streaming=True)

    # Step 1: Collect rows from the streaming dataset, then preprocess in parallel
    print(f"Loading and preprocessing {limit} images using {NUM_WORKERS} workers...")
    image_data = []
    count = 0

    # Collect raw data first
    for row in ds:
        if image_col not in row or row[image_col] is None:
            continue
        image_data.append((row, image_col, count))
        count += 1
        if len(image_data) >= limit:
            break

    # Preprocess images in parallel
    with mp.Pool(NUM_WORKERS) as pool:
        results = list(tqdm.tqdm(
            pool.imap(_preprocess_image_worker, image_data),
            total=len(image_data),
            desc="Preprocessing images"
        ))

    # Filter out None results and sort by index
    valid_results = [(img, idx) for img, idx in results if img is not None]
    valid_results.sort(key=lambda x: x[1])  # Sort by original index
    images = [img for img, _ in valid_results]

    print(f"Successfully preprocessed {len(images)} images")

    # Step 2: Embed images in batches (GPU intensive, keep single-threaded)
    print("Computing embeddings...")
    all_emb = []
    with torch.inference_mode():
        for i in tqdm.tqdm(range(0, len(images), batch_size), desc="Computing embeddings"):
            batch = images[i:i + batch_size]
            if not batch:
                continue
            inputs = processor.process_images(batch)
            inputs.to(device)
            emb = model(**inputs)
            all_emb.append(torch.nn.functional.normalize(emb, dim=-1).to("cpu"))

    if not all_emb:
        return f"No images found in column '{image_col}' of {repo_id}:{split}"

    INDEX_IMAGES = images
    INDEX_EMB = torch.cat(all_emb, dim=0).to(device)
    return (f"Indexed {len(images)} images from {repo_id}:{split} "
            f"(resized to {TARGET_SIZE[0]}x{TARGET_SIZE[1]}) - Used {NUM_WORKERS} workers")
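
# ---------------------------------------------------------------------------
# Hedged sketch, optional and not called anywhere: persisting the index with
# plain torch.save / torch.load would avoid re-embedding 1000 images on every
# restart. INDEX_PATH and both helpers are assumptions added for illustration;
# images are stored as uint8 tensors so torch.load's default weights-only
# mode can read them back.
# ---------------------------------------------------------------------------
INDEX_PATH = "index_cache.pt"


def save_index(path: str = INDEX_PATH):
    if INDEX_EMB is None:
        return "Nothing to save - build an index first"
    torch.save(
        {
            "embeddings": INDEX_EMB.to("cpu"),
            "images": [torch.from_numpy(np.array(img)) for img in INDEX_IMAGES],
        },
        path,
    )
    return f"Saved {len(INDEX_IMAGES)} images to {path}"


def load_index(path: str = INDEX_PATH):
    global INDEX_IMAGES, INDEX_EMB
    payload = torch.load(path, map_location="cpu")
    INDEX_EMB = payload["embeddings"].to(device)
    INDEX_IMAGES = [Image.fromarray(t.numpy()) for t in payload["images"]]
    return f"Loaded {len(INDEX_IMAGES)} images from {path}"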

with gr.Blocks(theme='default') as demo:
    with gr.Tabs():
        # Tab 1: Image Search
        with gr.Tab("🖼️ Image Search"):
            gr.Markdown("# ColModernVBert Image Search")
            gr.Markdown("⚠️ **First load takes ~2-3 minutes**: auto-indexing 1000 images from the ImageNet-1K validation set")
            with gr.Row():
                with gr.Column():
                    query = gr.Textbox(label="Text query", value="a baroque painting")
                    topk = gr.Slider(1, 8, value=3, step=1, label="Top-K")
                    btn = gr.Button("Search")
            out = gr.Gallery(label="Results")

        # Tab 2: Real-time Testing & Attention Visualization
        with gr.Tab("🧪 Model Testing"):
            gr.Markdown("# Real-time Text-Image Alignment Testing")
            gr.Markdown("Upload **at least 2 images** and test them against text queries to analyze model behavior")
            with gr.Row():
                with gr.Column():
                    test_text = gr.Textbox(
                        label="Test Query Text",
                        placeholder="Enter text like 'red car', 'dog playing', 'modern architecture'",
                        value="red sports car"
                    )
                    test_images = gr.File(
                        file_count="multiple",
                        file_types=["image"],
                        label="Upload Images (min. 2 required)"
                    )
                    test_btn = gr.Button("🧠 Test Model Alignment", variant="primary")
                with gr.Column():
                    attention_viz = gr.Image(label="Attention Heatmap", type="pil")
            with gr.Row():
                test_results = gr.Gallery(label="Image Similarity Results (2+ images)", columns=2)
            test_analysis = gr.Markdown(label="Detailed Analysis")

            test_btn.click(
                fn=test_text_image_alignment,
                inputs=[test_text, test_images],
                outputs=[test_analysis, test_results, attention_viz]
            )

        # Tab 3: Dataset Management
        with gr.Tab("📚 Dataset Management"):
            gr.Markdown("# Manage Image Index")
            with gr.Row():
                with gr.Column():
                    up = gr.File(file_count="multiple", type="filepath", label="Upload images to index")
                    status = gr.Textbox(label="Index status", interactive=False)
                    build = gr.Button("Build Index")
            with gr.Accordion("Load from HF dataset", open=True):
                repo = gr.Textbox(label="Dataset repo_id", value="imagenet-1k")
                split = gr.Textbox(label="Split", value="validation")
                img_col = gr.Textbox(label="Image column", value="image")
                lim = gr.Number(label="Max images", value=1000, precision=0)
                bsize = gr.Number(label="Batch size", value=64, precision=0)
                build_ds = gr.Button("Build Index from Dataset")
                status_ds = gr.Textbox(label="Index status", interactive=False)

    # Event handlers
    btn.click(fn=search, inputs=[query, topk], outputs=out)
    build.click(fn=upload_and_build, inputs=[up], outputs=status)
    build_ds.click(
        lambda r, s, c, l, b: build_index_from_dataset(r, s, c, int(l), int(b)),
        inputs=[repo, split, img_col, lim, bsize],
        outputs=status_ds,
    )
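
# Optional, an assumption rather than part of the original app: enabling
# Gradio's request queue can keep long-running dataset builds from hitting
# HTTP timeouts. Left commented out as a sketch.
# demo.queue(max_size=16)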

if __name__ == "__main__":
    # Build the index up front so the first search is fast; this blocks startup
    # for a few minutes while the ImageNet-1K images are downloaded and embedded.
    status_msg = ensure_index()
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))