# Spaces: Runtime error
| import joblib | |
| from sentence_transformers import CrossEncoder, SentenceTransformer | |
| import streamlit as st | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from samples import get_samples | |
| import textdistance | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from encode_sentences import encode_sentences | |
# Path of the fine-tuned checkpoint; both the cross-encoder and the
# bi-encoder are loaded from it.
model_save_path = 'trained_model_stsbenchmark_bert-base-uncased'

# Display names of the comparison methods. They double as the identifiers
# that the comparison logic matches against, so they must stay unique.
bi_encoder = 'Bi-Encoder'
cross_encoder = 'Cross-Encoder'
levenshtein_distance = 'Levenshtein Distance'
tf_idf = 'TF-IDF'
random_forest = 'Random Forest'

# Shared by the browser tab title and the page heading.
title = 'Sentence Similarity with Transformers'

st.set_page_config(
    page_title=title,
    layout='wide',
    initial_sidebar_state='auto',
)
@st.cache_resource
def _load_models():
    """Load the heavyweight models once and reuse them across script reruns.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` keeps the returned objects alive between reruns,
    so the transformer checkpoints and the random-forest model are read
    from disk only on the first call.

    Returns:
        Tuple ``(cross_encoder_model, bi_encoder_model, random_forest_model)``.
    """
    cross_encoder_model = CrossEncoder(model_save_path)
    bi_encoder_model = SentenceTransformer(model_save_path)
    # NOTE: joblib.load unpickles the file; only load trusted artifacts.
    forest_model = joblib.load('trained_model_random_forest.joblib')
    return cross_encoder_model, bi_encoder_model, forest_model


def cache_variables():
    """Return the objects needed by the comparison methods.

    The TF-IDF vectorizer is created fresh on every call: it is cheap to
    build and is mutated by ``fit_transform``, so it is deliberately not
    part of the cached resources.

    Returns:
        Tuple ``(tfidf_vectorizer, cross_encoder, bi_encoder,
        random_forest_model)``, in that order.
    """
    tfidf_vectorizer = TfidfVectorizer()
    return (tfidf_vectorizer, *_load_models())
def compute_similarity(sentence_1, sentence_2, comparison):
    """Score a sentence pair with one of the transformer models.

    Args:
        sentence_1: First sentence.
        sentence_2: Second sentence.
        comparison: Method name. When it equals ``bi_encoder`` the
            bi-encoder is used; any other value falls through to the
            cross-encoder.

    Returns:
        For the bi-encoder, the cosine similarity between the two sentence
        embeddings. Otherwise, the cross-encoder's prediction for the pair.
    """
    if comparison != bi_encoder:
        return cross_encoder_trasformer.predict([sentence_1, sentence_2])

    # Encode each sentence separately, then compare the two embeddings.
    embedding_1 = bi_encoder_trasformer.encode(sentence_1)
    embedding_2 = bi_encoder_trasformer.encode(sentence_2)
    return cosine_similarity([embedding_1], [embedding_2])[0][0]
# Module-level model handles; compute_similarity reads the two encoders
# from these globals.
tfidf_vectorizer, cross_encoder_trasformer, bi_encoder_trasformer, random_forest_model = cache_variables()

st.title(title)
st.write("This app takes two sentences and outputs their similarity score using a fine-tuned transformer model.")

# Example sentences section: the sidebar radios pick the default text of
# the two input fields below.
test_samples = get_samples()
st.sidebar.header("Example Sentences")
example_1 = st.sidebar.radio(
    "Sentence 1", test_samples['sentence1'].values.tolist())
example_2 = st.sidebar.radio(
    "Sentence 2", test_samples['sentence2'].values.tolist())

# Input fields
sentence_1 = st.text_input("Enter Sentence 1:", example_1)
sentence_2 = st.text_input("Enter Sentence 2:", example_2)
# Label typo fixed ("Comparicon" -> "Comparison").
comparison = st.selectbox("Comparison:", [
    bi_encoder, cross_encoder, levenshtein_distance, tf_idf, random_forest])
if st.button("Compare"):
    # Compute similarity with the selected method. ``similarity`` stays
    # None when no score could be produced, in which case nothing is shown.
    similarity = None
    if comparison in [bi_encoder, cross_encoder]:
        similarity = compute_similarity(sentence_1, sentence_2, comparison)
    elif comparison == levenshtein_distance:
        similarity = textdistance.levenshtein.normalized_similarity(
            sentence_1, sentence_2)
    elif comparison == tf_idf:
        try:
            tfidf_matrix = tfidf_vectorizer.fit_transform(
                [sentence_1, sentence_2])
        except ValueError:
            # scikit-learn raises ValueError ("empty vocabulary") when no
            # token survives tokenization, e.g. both inputs are empty or
            # contain only one-character words. Report it instead of
            # crashing the app.
            st.warning(
                "TF-IDF could not find any usable words in the sentences. "
                "Enter text containing at least one word of two or more characters.")
        else:
            # Row 0 vs. row 1 of the pairwise matrix is sentence 1 vs. 2.
            similarity = cosine_similarity(tfidf_matrix)[0][1]
    elif comparison == random_forest:
        similarity = random_forest_model.predict(encode_sentences(
            bi_encoder_trasformer, sentence_1, sentence_2))[0]

    if similarity is not None:
        st.markdown(
            f"<b style='font-size: 1.5em'>{comparison}</b> similarity score: <b style='font-size: 1.5em'>:red[{similarity:.4f}]</b>", unsafe_allow_html=True)
        st.write(
            "A higher score indicates greater similarity. The score ranges from 0 to 1.")