| import time | |
| import json | |
| from pyserini.search.lucene import LuceneImpactSearcher | |
| import streamlit as st | |
| from pathlib import Path | |
| import sys | |
# Register the directory "./" (resolved against the working directory at
# import time) on the module search path. The entry is appended, so entries
# already on the path keep precedence.
path_root = Path("./")
_search_path_entry = str(path_root)
sys.path.append(_search_path_entry)
# Maps the label shown in the UI to a pair:
#   (query-encoder name handed to LuceneImpactSearcher,
#    name of the index directory under ``indexes/``).
encoder_index_map = {
    'uniCOIL': (
        'UniCoil',
        'index-unicoil',
    ),
    'SPLADE++ Ensemble Distil': (
        'SpladePlusPlusEnsembleDistil',
        'index-splade-pp-ed',
    ),
    'SPLADE++ Self Distil': (
        'SpladePlusPlusSelfDistil',
        'index-splade-pp-sd',
    ),
}

# Initial encoder/index pair; taken from the map so the two cannot disagree.
encoder, index = encoder_index_map['SPLADE++ Ensemble Distil']
# ---- Page setup ------------------------------------------------------------
# NOTE(review): the page_icon value below (and the button label further down)
# look like mis-encoded emoji; they are kept exactly as found -- confirm the
# intended characters.
st.set_page_config(
    page_title="Pyserini with ONNX Runtime",
    page_icon='πΈ',
    layout="centered",
)

# Logo goes in the middle of three columns with relative widths 5:4:5.
cola, colb, colc = st.columns([5, 4, 5])
with colb:
    st.image("logo.jpeg")

# ---- Encoder selection -----------------------------------------------------
# The slider options are read from ``encoder_index_map`` itself so that every
# selectable label is guaranteed to have an entry in the lookup below.
colaa, colbb, colcc = st.columns([1, 8, 1])
with colbb:
    selected_label = st.select_slider(
        'Select a query encoder with ONNX Runtime',
        options=list(encoder_index_map))
    st.write('Now Running Encoder: ', selected_label)

# Resolve the chosen label to the encoder name and index directory name that
# the rest of the script uses.
encoder, index = encoder_index_map[selected_label]

# ---- Query input -----------------------------------------------------------
# Wide text box next to a narrow button column (relative widths 9:1).
col1, col2 = st.columns([9, 1])
with col1:
    search_query = st.text_input(label="search query", placeholder="Search")
with col2:
    # Presumably a vertical spacer so the button lines up with the text box
    # -- TODO confirm.
    st.write('#')
    button_clicked = st.button("π")
# Streamlit re-executes the whole script on each interaction, so building the
# searcher inline would reopen the index and reload the ONNX encoder every
# time. Use Streamlit's resource cache when this Streamlit version offers it;
# otherwise fall back to constructing the searcher directly on every run.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_searcher(index_name: str, encoder_name: str) -> LuceneImpactSearcher:
    """Build the impact searcher for one index/encoder pair.

    Args:
        index_name: Directory name of the index under ``indexes/``.
        encoder_name: Query-encoder name, run with ``encoder_type='onnx'``.

    Returns:
        A ``LuceneImpactSearcher`` over ``indexes/<index_name>``. When the
        resource cache is available the same instance is returned for
        repeated calls with the same arguments.
    """
    return LuceneImpactSearcher(
        f'indexes/{index_name}', f'{encoder_name}', encoder_type='onnx')


searcher = _load_searcher(index, encoder)
# ---- Run the query and render the hits -------------------------------------
top_k = 10  # number of hits requested from the searcher and displayed

# A click on the button with an empty (or whitespace-only) box must not be
# sent to the searcher as an empty query.
has_query = bool(search_query and search_query.strip())

if button_clicked and not has_query:
    st.warning('Please enter a search query.')

if has_query:
    print("search query is:\t", search_query)

    # perf_counter is monotonic, so the reported latency cannot be skewed by
    # system clock adjustments.
    t_0 = time.perf_counter()
    search_results = searcher.search(search_query, k=top_k)
    search_time = time.perf_counter() - t_0

    st.write(
        f'<p align=\"right\" style=\"color:grey;\">Retrieved {len(search_results):,.0f} documents in {search_time*1000:.2f} ms</p>',
        unsafe_allow_html=True)

    # One row per hit: 1-based rank, document id and score, then a divider.
    for rank, result in enumerate(search_results, start=1):
        result_score = result.score
        result_id = result.docid
        output = (
            f'<div class="row"> <div class="column"> <b>Rank</b>: {rank} </div>'
            f'<div class="column"><b>Document ID</b>: {result_id}</div>'
            f'<div class="column"><b>Score</b>:{result_score:.2f}</div></div>'
        )
        try:
            st.write(output, unsafe_allow_html=True)
        except Exception as exc:
            # Rendering stays best-effort, but the failure is reported rather
            # than hidden. Catching Exception (not a bare except) lets
            # BaseException-derived signals such as KeyboardInterrupt
            # propagate.
            st.warning(f'Could not display result {rank}: {exc}')
        st.write('---')