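"""Gradio demo app for the categorical CondViT-B/16 retrieval model.

A query image is embedded conditioned on a chosen fashion category, and the
K most similar images are retrieved from a pre-built FAISS gallery index.
"""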
import pandas as pd
import torch
import faiss
import gradio as gr
import base64
from PIL import Image
from io import BytesIO

from src.model import ConditionalViT, B16_Params, categories
from src.transform import valid_tf
from src.process_images import process_img, make_img_html
from src.examples import ExamplesHandler
from src.js_loader import JavaScriptLoader
# Load the categorical CondViT-B/16 checkpoint on CPU and switch to inference mode
m = ConditionalViT(**B16_Params, n_categories=len(categories))
m.load_state_dict(torch.load("./artifacts/cat_condvit_b16.pth", map_location="cpu"))
m.eval()
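# ConditionalViT takes (image batch, category indices) and returns query
# embeddings in the same space as the pre-computed gallery embeddings
# indexed by FAISS below.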
# Load the gallery: a FAISS index of embeddings plus a parquet of the matching images
index = faiss.read_index("./artifacts/gallery_index.faiss")
gal_imgs = pd.read_parquet("./artifacts/gallery_imgs.parquet")

tfs = valid_tf((224, 224))  # validation-time preprocessing for 224x224 inputs
K = 5  # number of nearest neighbors to retrieve
examples = [
    ["examples/3.jpg", "Outwear"],
    ["examples/3.jpg", "Lower Body"],
    ["examples/3.jpg", "Feet"],
    ["examples/757.jpg", "Bags"],
    ["examples/757.jpg", "Upper Body"],
    ["examples/769.jpg", "Upper Body"],
    ["examples/1811.jpg", "Lower Body"],
    ["examples/1811.jpg", "Bags"],
]
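# Each example is [image path, category name]. The dropdown below uses
# type="index", so retrieval() receives the category's position in
# `categories` rather than the string itself.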
def retrieval(image, category):
    if image is None or category is None:
        return

    # Embed the query image conditioned on the selected category.
    # no_grad + numpy conversion: FAISS expects a plain float32 array, and
    # calling .numpy() on a tensor that still tracks gradients raises an error.
    with torch.no_grad():
        q_emb = m(tfs(image).unsqueeze(0), torch.tensor([category]))
    _, ids = index.search(q_emb.numpy(), K)  # FAISS returns (distances, indices)

    imgs = [process_img(idx, gal_imgs) for idx in ids[0]]
    html = [make_img_html(i) for i in imgs]
    html += ["<p></p>"]  # Dodge Gradio's last-child { margin-bottom: 0 !important; }
    return "\n".join(html)
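# Hypothetical local smoke test (kept commented out so the Space never runs it;
# assumes examples/3.jpg exists and "Upper Body" is one of `categories`):
#   html = retrieval(Image.open("examples/3.jpg"), categories.index("Upper Body"))
#   print(html[:200])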
# Load the custom front-end JS helpers used by the demo
JavaScriptLoader("src/custom_functions.js")

with gr.Blocks(css="src/style.css") as demo:
    with gr.Column():
        gr.Markdown(
            """
            # Conditional ViT Demo

            [[`Paper`](https://arxiv.org/abs/2306.02928)]
            [[`Code`](https://github.com/Simon-Lepage/CondViT-LRVSF)]
            [[`Dataset`](https://huggingface.co/datasets/Slep/LAION-RVS-Fashion)]
            [[`Model`](https://huggingface.co/Slep/CondViT-B16-cat)]

            *Running on 2 vCPU, 16 GB RAM.*

            - **Model:** Categorical CondViT-B/16
            - **Gallery:** 93K images
            """
        )
        # Input section
        with gr.Row():
            img = gr.Image(label="Query Image", type="pil", elem_id="query_img")
            with gr.Column():
                cat = gr.Dropdown(
                    choices=categories,
                    label="Category",
                    value="Upper Body",
                    type="index",
                    elem_id="dropdown",
                )
                submit = gr.Button("Submit")
        # Examples
        gr.Examples(
            examples,
            inputs=[img, cat],
            fn=retrieval,
            elem_id="preset_examples",
            examples_per_page=100,
        )
        gr.HTML(
            value=ExamplesHandler(examples).to_html(),
            label="examples",
            elem_id="html_examples",
        )
        # Outputs
        gr.Markdown("# Retrieved Items")
        out = gr.HTML(label="Results", elem_id="html_output")
        submit.click(fn=retrieval, inputs=[img, cat], outputs=out)

demo.launch()