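# Streamlit page: text-to-image retrieval with a fine-tuned CLIP model
# (flax-community/clip-rsicd-v2) over precomputed image vectors.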
import matplotlib.pyplot as plt
import nmslib
import numpy as np
import os
import streamlit as st

from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel

import utils
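# The baseline checkpoint presumably supplies the CLIPProcessor and the
# fine-tuned checkpoint the weights (see utils.load_model); the image
# vectors were precomputed offline into a TSV file.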
BASELINE_MODEL = "openai/clip-vit-base-patch32"
MODEL_PATH = "flax-community/clip-rsicd-v2"
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
IMAGES_DIR = "./images"
CAPTIONS_FILE = os.path.join(IMAGES_DIR, "test-captions.json")


def app():
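    # Load the NMSLib index over the precomputed image vectors, the
    # fine-tuned CLIP model and processor, and the reference captions.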
    filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
    model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
    image2caption = utils.load_captions(CAPTIONS_FILE)

    st.title("Retrieve Images given Text")
    st.markdown("""
This demo shows the text-to-image retrieval capabilities of this model, i.e.,
given a text query, we use our fine-tuned CLIP model to project the query
into the image/caption embedding space and search for nearby images (by
cosine similarity) in that space.

Our fine-tuned CLIP model was previously used to generate the image vectors
for this demo, and NMSLib is used for fast nearest-neighbor retrieval over
them.
    """)
    suggested_query = [
        "ships",
        "school house",
        "military installation",
        "mountains",
        "beaches",
        "airports",
        "lakes",
    ]
    st.text("Some suggested queries to start you off with...")
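    # One suggestion button per column; clicking a button selects that query.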
    cols = st.columns(7)  # alternative widths: [1, 1.1, 1.3, 1.1, 1, 1, 1]
    suggest_idx = -1
    for i, col in enumerate(cols):
        with col:
            if st.button(suggested_query[i]):
                suggest_idx = i
    query = st.text_input("OR enter a text query:")
    query = suggested_query[suggest_idx] if suggest_idx > -1 else query
    if st.button("Query") or suggest_idx > -1:
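        # Project the text query into the shared embedding space with CLIP's
        # text encoder, then fetch the 10 nearest image vectors from the index.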
        inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
        query_vec = model.get_text_features(**inputs)
        query_vec = np.asarray(query_vec)
        ids, distances = index.knnQuery(query_vec, k=10)
        result_filenames = [filenames[i] for i in ids]
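        # Render each hit: rank, image, and its reference captions. The index
        # returns cosine distances, so similarity is reported as 1.0 - distance.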
        for rank, (result_filename, score) in enumerate(zip(result_filenames, distances)):
            caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score)
            col1, col2, col3 = st.columns([2, 10, 10])
            col1.markdown("{:d}.".format(rank + 1))
            col2.image(Image.open(os.path.join(IMAGES_DIR, result_filename)),
                       caption=caption)
            caption_text = []
            for caption in image2caption[result_filename]:
                caption_text.append("* {:s}\n".format(caption))
            col3.markdown("".join(caption_text))
            st.markdown("---")
        suggest_idx = -1