# Hugging Face Space status: Runtime error
import streamlit as st

# Page metadata. Streamlit requires st.set_page_config to run before any other
# Streamlit command in the script, which is why it directly follows the import.
# NOTE(review): the original page_icon was mojibake ("π§") -- the UTF-8 bytes of an
# emoji in the U+1F9C0..U+1F9FF range decoded as Greek ISO-8859-7/cp1253. "🧐" is
# the most plausible original; confirm the intended icon.
st.set_page_config(page_title='T2I', page_icon="🧐", layout='centered')
st.title("Text To Image Retrieval for KaggleX BPIOC Mentorship Program")
| import torch | |
| from transformers import AutoTokenizer, AutoModel | |
| import faiss | |
| import numpy as np | |
| from PIL import Image | |
| from sentence_transformers import SentenceTransformer | |
| import json | |
| import zipfile | |
# Load the caption metadata. Keys are image file names inside Images.zip (see how
# search() opens "Images/" + key); values are presumably the captions whose
# precomputed embeddings live in sbert_text_features.npy -- TODO confirm that the
# JSON key order matches the row order of that file.
image_map_name = 'captions.json'
# Explicit UTF-8 so non-ASCII captions do not depend on the platform's locale.
with open(image_map_name, 'r', encoding='utf-8') as f:
    caption_dict = json.load(f)
# Parallel lists: FAISS hit i is mapped back to its image via image_list[i].
image_list = list(caption_dict.keys())
caption_list = list(caption_dict.values())

# Archive holding the images under an "Images/" prefix. It stays open for the
# life of the script so search() can read individual members on demand.
zip_path = "Images.zip"
zip_file = zipfile.ZipFile(zip_path)
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the model, tokenizer and index below are rebuilt on each click.
# Consider wrapping these loads in st.cache_resource if the deployed Streamlit
# version provides it.
model_name = "sentence-transformers/all-distilroberta-v1"
# NOTE(review): `tokenizer` is not referenced anywhere else in the visible code;
# it is kept only to preserve existing module state.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = SentenceTransformer(model_name)

# Caption embeddings are precomputed offline (presumably via
# model.encode(caption_list)) and loaded from disk instead of encoded at startup.
# FAISS only accepts C-contiguous float32 arrays; coerce whatever dtype was saved
# (this is a no-op when the file is already float32).
vectors = np.ascontiguousarray(np.load("./sbert_text_features.npy"), dtype=np.float32)
vector_dimension = vectors.shape[1]

# Exact inner-product index over L2-normalised vectors, i.e. cosine similarity.
index = faiss.IndexFlatIP(vector_dimension)
faiss.normalize_L2(vectors)  # normalises in place
index.add(vectors)
def search(query, k=4):
    """Render up to ``k`` images whose captions best match ``query``.

    The query is embedded with the same SentenceTransformer ``model`` used for the
    caption vectors, L2-normalised, and looked up in the inner-product FAISS
    ``index``, so hits are ranked by cosine similarity. Each hit's image is read
    from ``zip_file`` and displayed with ``st.image``.

    Args:
        query: Free-text search string.
        k: Maximum number of images to display.

    Returns:
        None. Results are rendered directly into the Streamlit page.
    """
    # FAISS expects a 2-D, C-contiguous float32 batch; wrap the single embedding
    # as a batch of one.
    query_vector = np.asarray([model.encode(query)], dtype=np.float32)
    faiss.normalize_L2(query_vector)

    # IndexFlatIP is an exhaustive search; there is no nprobe to tune
    # (nprobe only applies to IVF indexes), so none is set here.
    _, ids = index.search(query_vector, k)

    for row in ids[0]:
        # FAISS pads results with -1 when the index holds fewer than k vectors;
        # image_list[-1] would otherwise silently show the last image.
        if row < 0:
            continue
        image_id = str(image_list[row])
        # Close the archive member once the image has been rendered.
        with zip_file.open("Images/" + image_id) as image_data:
            image = Image.open(image_data)
            st.image(image, width=600)
# Search UI: run a search only on an explicit click with a non-blank query.
query = st.text_input("Enter your search query here:")
# st.button is evaluated first so the button is always rendered; a blank or
# whitespace-only query is ignored rather than embedded and searched.
if st.button("Search") and query.strip():
    search(query)