import os
import multiprocessing as mp
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from PIL import Image
from datasets import load_dataset
from huggingface_hub import hf_hub_download

from colpali_engine.models import ColModernVBert, ColModernVBertProcessor
from colpali_engine.utils.torch_utils import get_torch_device
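
# Load the ColModernVBert retriever and its processor once at startup.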
MODEL_ID = "ModernVBERT/colmodernvbert"
device = get_torch_device("auto")
processor = ColModernVBertProcessor.from_pretrained(MODEL_ID)
model = ColModernVBert.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    trust_remote_code=True,
)
model.to(device)
model.eval()
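
# In-memory index state: the raw PIL images and their L2-normalized embeddings.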
INDEX_IMAGES = []
INDEX_EMB = None
TARGET_SIZE = (512, 512)
NUM_WORKERS = max(1, mp.cpu_count() // 2)  # Use half the CPU cores (at least one) to avoid contention

def _ensure_size(img: Image.Image) -> Image.Image:
    if img.size != TARGET_SIZE:
        return img.resize(TARGET_SIZE, Image.BICUBIC)
    return img
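
# Download three demo images from the SmolVLM Space.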
def load_sample_images():
    paths = [
        hf_hub_download("HuggingFaceTB/SmolVLM", "example_images/rococo.jpg", repo_type="space"),
        hf_hub_download("HuggingFaceTB/SmolVLM", "example_images/astronaut.png", repo_type="space"),
        hf_hub_download("HuggingFaceTB/SmolVLM", "example_images/cat.png", repo_type="space"),
    ]
    return [_ensure_size(Image.open(p).convert("RGB")) for p in paths]
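
# Build an in-memory index (images plus normalized embeddings) from a list of PIL images.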
def build_index(images):
    global INDEX_IMAGES, INDEX_EMB
    processed = [_ensure_size(img.convert("RGB")) for img in images]
    INDEX_IMAGES = processed
    with torch.inference_mode():
        inputs = processor.process_images(processed).to(device)
        emb = model(**inputs)
        INDEX_EMB = F.normalize(emb, dim=-1)
    return f"Indexed {len(processed)} images (resized to {TARGET_SIZE[0]}x{TARGET_SIZE[1]})"
def ensure_index():
    if INDEX_IMAGES:
        return "Index already built"
    # Auto-load 1000 images from the ImageNet-1K validation split.
    # Note: imagenet-1k is a gated dataset on the Hub, so this requires authentication.
    print("Auto-loading 1000 images from ImageNet-1K (this may take a few minutes)...")
    status = build_index_from_dataset("imagenet-1k", "validation", "image", 1000, 64)
    print(f"Auto-indexing completed: {status}")
    return status
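
# Retrieve the top-k indexed images for a text query via cosine similarity.
# Note: this assumes the model returns one pooled vector per input; if ColModernVBert
# instead returns multi-vector (late-interaction) embeddings, colpali_engine's
# processor.score_multi_vector would be the appropriate scoring utility.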
def search(query, top_k=3):
    ensure_index()
    with torch.inference_mode():
        q_inputs = processor.process_texts([query]).to(device)
        q_emb = model(**q_inputs)
        q_emb = F.normalize(q_emb, dim=-1)
        sims = (q_emb @ INDEX_EMB.T).squeeze(0)
        vals, idxs = torch.topk(sims, k=min(int(top_k), len(INDEX_IMAGES)))
    results = [(INDEX_IMAGES[i], f"score={vals[j].item():.4f}") for j, i in enumerate(idxs.tolist())]
    return results

def upload_and_build(files):
    if not files:
        return "No files uploaded"
    # gr.File may yield plain paths or tempfile-like objects depending on the Gradio version.
    paths = [f if isinstance(f, str) else f.name for f in files]
    images = [_ensure_size(Image.open(p).convert("RGB")) for p in paths]
    return build_index(images)

def visualize_attention(text_embed, img_embeds):
    """Visualize similarity between text and image embeddings as a heatmap."""
    # Normalize embeddings
    text_norm = F.normalize(text_embed, dim=-1)
    img_norm = F.normalize(img_embeds, dim=-1)
    # Compute pairwise similarity scores
    attention_scores = torch.matmul(text_norm, img_norm.transpose(-2, -1))
    scores = attention_scores.squeeze().detach().cpu().numpy()
    # Ensure a 2-D array so imshow works even for a single text embedding
    scores = np.atleast_2d(scores)
    fig, ax = plt.subplots(figsize=(10, 6))
    im = ax.imshow(scores, cmap='YlOrRd', aspect='auto')
    ax.set_title('Text-Image Attention Map')
    ax.set_xlabel('Image Embeddings')
    ax.set_ylabel('Text Embeddings')
    # Add colorbar
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    # Render to an in-memory PNG and return a PIL image (the gr.Image output uses type="pil")
    buf = BytesIO()
    fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    buf.seek(0)
    heatmap = Image.open(buf).convert("RGB")
    plt.close(fig)
    return heatmap

def test_text_image_alignment(text_inputs, image_files):
    """Test alignment between a text query and uploaded images with real-time comparison."""
    if not image_files or len(image_files) < 2:
        return "❌ At least 2 images required for comparison", None, None
    if not text_inputs.strip():
        return "❌ Text input required", None, None
    try:
        # Load and resize the uploaded images
        images = []
        for f in image_files:
            path = f if isinstance(f, str) else f.name
            images.append(_ensure_size(Image.open(path).convert("RGB")))
        with torch.inference_mode():
            # Text embedding
            text_processed = processor.process_texts([text_inputs]).to(device)
            text_embed = F.normalize(model(**text_processed), dim=-1)
            # Image embeddings
            img_processed = processor.process_images(images).to(device)
            img_embeds = F.normalize(model(**img_processed), dim=-1)
            # Cosine similarity between the query and each image
            similarities = F.cosine_similarity(text_embed, img_embeds, dim=-1)
        # Build the per-image comparison results
        results = []
        for img, sim_score in zip(images, similarities):
            sim_val = sim_score.item()
            # Score interpretation
            if sim_val > 0.7:
                interpretation = "🟢 Strong match"
            elif sim_val > 0.4:
                interpretation = "🟡 Moderate match"
            else:
                interpretation = "🔴 Weak match"
            results.append((img, f"Similarity: {sim_val:.4f} - {interpretation}"))
        # Similarity heatmap between the text and image embeddings
        attention_viz = visualize_attention(text_embed, img_embeds)
        # Detailed analysis
        analysis = f"""
**Real-time Testing Results:**
📝 **Query Text:** "{text_inputs}"
🖼️ **Images Tested:** {len(images)}
**Similarity Scores:**
"""
        for i, sim_val in enumerate(similarities):
            analysis += f"- Image {i+1}: {sim_val:.4f}\n"
        analysis += f"""
**Best Match:** Image #{torch.argmax(similarities).item() + 1} (score: {similarities.max():.4f})
**Average Score:** {similarities.mean():.4f}
**Score Range:** {similarities.min():.4f} - {similarities.max():.4f}
**Model Training Evidence:**
✅ Text understanding: Model processes natural language
✅ Image understanding: Model processes visual content
✅ Cross-modal alignment: Computes meaningful similarities
✅ Attention mechanism: Learns text-image relationships
"""
        return analysis, results, attention_viz
    except Exception as e:
        return f"❌ Error during testing: {str(e)}", None, None

def _preprocess_image_worker(args):
    """Worker function for preprocessing a single dataset row in parallel."""
    row, image_col, index = args
    if image_col not in row or row[image_col] is None:
        return None, index
    img = row[image_col]
    if hasattr(img, "convert"):
        img = img.convert("RGB")
    return _ensure_size(img), index
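
# Stream a Hugging Face dataset, preprocess up to `limit` images in parallel, and embed them in batches.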
def build_index_from_dataset(repo_id: str, split: str = "train", image_col: str = "image", limit: int = 500, batch_size: int = 64):
    global INDEX_IMAGES, INDEX_EMB
    ds = load_dataset(repo_id, split=split, streaming=True)
    # Step 1: collect raw rows, then preprocess the images in parallel
    print(f"Loading and preprocessing {limit} images using {NUM_WORKERS} workers...")
    image_data = []
    count = 0
    for row in ds:
        if image_col not in row or row[image_col] is None:
            continue
        image_data.append((row, image_col, count))
        count += 1
        if len(image_data) >= limit:
            break
    with mp.Pool(NUM_WORKERS) as pool:
        results = list(tqdm.tqdm(
            pool.imap(_preprocess_image_worker, image_data),
            total=len(image_data),
            desc="Preprocessing images",
        ))
    # Filter out failures and restore the original order
    valid_results = [(img, idx) for img, idx in results if img is not None]
    valid_results.sort(key=lambda x: x[1])
    images = [img for img, _ in valid_results]
    print(f"Successfully preprocessed {len(images)} images")
    # Step 2: embed images in batches (GPU-bound, so keep it single-threaded)
    print("Computing embeddings...")
    all_emb = []
    with torch.inference_mode():
        for i in tqdm.tqdm(range(0, len(images), batch_size), desc="Computing embeddings"):
            batch = images[i:i + batch_size]
            inputs = processor.process_images(batch).to(device)
            emb = model(**inputs)
            all_emb.append(F.normalize(emb, dim=-1).to("cpu"))
    INDEX_IMAGES = images
    INDEX_EMB = torch.cat(all_emb, dim=0).to(device)
    return f"Indexed {len(images)} images from {repo_id}:{split} (resized to {TARGET_SIZE[0]}x{TARGET_SIZE[1]}) - used {NUM_WORKERS} workers"
with gr.Blocks(theme='default') as demo:
    with gr.Tabs():
        # Tab 1: Image Search
        with gr.Tab("🖼️ Image Search"):
            gr.Markdown("# ColModernVBert Image Search")
            gr.Markdown("⚠️ **First load takes ~2-3 minutes**: auto-indexing 1000 images from the ImageNet-1K validation set")
            with gr.Row():
                with gr.Column():
                    query = gr.Textbox(label="Text query", value="a baroque painting")
                    topk = gr.Slider(1, 8, value=3, step=1, label="Top-K")
                    btn = gr.Button("Search")
            out = gr.Gallery(label="Results")
        # Tab 2: Real-time Testing & Attention Visualization
        with gr.Tab("🧪 Model Testing"):
            gr.Markdown("# Real-time Text-Image Alignment Testing")
            gr.Markdown("Upload **at least 2 images** and test with text queries to analyze model behavior")
            with gr.Row():
                with gr.Column():
                    test_text = gr.Textbox(
                        label="Test Query Text",
                        placeholder="Enter text like 'red car', 'dog playing', 'modern architecture'",
                        value="red sports car",
                    )
                    test_images = gr.File(
                        file_count="multiple",
                        file_types=["image"],
                        label="Upload Images (min 2 required)",
                    )
                    test_btn = gr.Button("🔧 Test Model Alignment", variant="primary")
                with gr.Column():
                    attention_viz = gr.Image(label="Attention Heatmap", type="pil")
            with gr.Row():
                test_results = gr.Gallery(label="Image Similarity Results", columns=2)
                test_analysis = gr.Markdown(label="Detailed Analysis")
            test_btn.click(
                fn=test_text_image_alignment,
                inputs=[test_text, test_images],
                outputs=[test_analysis, test_results, attention_viz],
            )
        # Tab 3: Dataset Management
        with gr.Tab("📚 Dataset Management"):
            gr.Markdown("# Manage Image Index")
            with gr.Row():
                with gr.Column():
                    up = gr.File(file_count="multiple", type="filepath", label="Upload images to index")
                    status = gr.Textbox(label="Index status", interactive=False)
                    build = gr.Button("Build Index")
            with gr.Accordion("Load from HF dataset", open=True):
                repo = gr.Textbox(label="Dataset repo_id", value="imagenet-1k")
                split = gr.Textbox(label="Split", value="validation")
                img_col = gr.Textbox(label="Image column", value="image")
                lim = gr.Number(label="Max images", value=1000, precision=0)
                bsize = gr.Number(label="Batch size", value=64, precision=0)
                build_ds = gr.Button("Build Index from Dataset")
                status_ds = gr.Textbox(label="Index status", interactive=False)
    # Event handlers
    btn.click(fn=search, inputs=[query, topk], outputs=out)
    build.click(fn=upload_and_build, inputs=[up], outputs=status)
    build_ds.click(
        fn=lambda r, s, c, l, b: build_index_from_dataset(r, s, c, int(l), int(b)),
        inputs=[repo, split, img_col, lim, bsize],
        outputs=status_ds,
    )

if __name__ == "__main__":
    # Build the default index before launching; search() also builds it lazily if this is skipped.
    ensure_index()
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))