Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import pandas as pd | |
| from pipelines.keyphrase_extraction_pipeline import KeyphraseExtractionPipeline | |
| from pipelines.keyphrase_generation_pipeline import KeyphraseGenerationPipeline | |
| import orjson | |
| from annotated_text.util import get_annotated_html | |
| from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode | |
| import re | |
| import numpy as np | |
# One-time session bootstrap: the config, the history table and the latest
# keyphrases are created only when the session does not hold a config yet.
if "config" not in st.session_state:
    # orjson parses UTF-8 bytes directly, so read the file in binary mode
    # instead of decoding it with the platform's locale encoding first.
    with open("config.json", "rb") as config_file:
        st.session_state.config = orjson.loads(config_file.read())
    st.session_state.data_frame = pd.DataFrame(columns=["model"])
    st.session_state.keyphrases = []

st.set_page_config(
    page_icon="π",
    page_title="Keyphrase extraction/generation with Transformers",
    layout="wide",
)

# The guard must test the same key that is assigned below; otherwise the
# stored selection would be reset on every rerun of the script.
if "selected_rows" not in st.session_state:
    st.session_state.selected_rows = []

st.header("π Keyphrase extraction/generation with Transformers")
# col1 holds the input widgets, col2 the annotated output.
col1, col2 = st.empty().columns(2)
def load_pipeline(chosen_model):
    """Build the pipeline that matches the task named in *chosen_model*.

    Args:
        chosen_model: Model name; the substring "keyphrase-extraction" or
            "keyphrase-generation" selects the pipeline class.

    Returns:
        A ``KeyphraseExtractionPipeline`` or ``KeyphraseGenerationPipeline``
        constructed with *chosen_model*, or ``None`` when the name contains
        neither task marker.
    """
    # Extraction is checked first, so it wins if both markers are present.
    if "keyphrase-extraction" in chosen_model:
        return KeyphraseExtractionPipeline(chosen_model)
    if "keyphrase-generation" in chosen_model:
        return KeyphraseGenerationPipeline(chosen_model)
    return None
def extract_keyphrases():
    """Run the active pipeline on the input text and record the result.

    Stores the pipeline output in ``st.session_state.keyphrases`` and appends
    one row to ``st.session_state.data_frame``: the chosen model, the input
    text, then one column per keyphrase. Cells left empty by rows of
    different lengths are filled with "".
    """
    state = st.session_state
    state.keyphrases = pipe(state.input_text)

    # Keyphrase columns are labelled by position: "0", "1", ...
    keyphrase_columns = [str(i) for i in range(len(state.keyphrases))]
    row = np.concatenate(
        ([state.chosen_model, state.input_text], state.keyphrases)
    )
    new_entry = pd.DataFrame(
        data=[row],
        columns=["model", "text"] + keyphrase_columns,
    )

    history = pd.concat(
        [state.data_frame, new_entry],
        ignore_index=True,
        axis=0,
    )
    state.data_frame = history.fillna("")
def get_annotated_text(text, keyphrases):
    """Split *text* into plain and highlighted pieces for annotated output.

    Args:
        text: The text to annotate.
        keyphrases: Keyphrases to highlight. They are located
            case-insensitively and treated as literal text, never as
            regular expressions.

    Returns:
        A list whose items are either plain strings (ordinary words padded
        with their separating spaces) or ``(phrase, "KEY", "#21c354")``
        tuples for highlighted keyphrases.
    """
    for keyphrase in keyphrases:
        # Glue the words of a keyphrase together with "$K" so that a
        # multi-word keyphrase survives the split on spaces below.
        joined = keyphrase.replace(" ", "$K")
        text = re.sub(
            # Escape the keyphrase: characters such as "+" or "(" must match
            # literally instead of being interpreted as regex syntax.
            f"({re.escape(keyphrase)})",
            # A callable replacement is inserted verbatim (no backslash or
            # group-reference processing of the keyphrase).
            lambda _match, joined=joined: joined,
            text,
            flags=re.I,
        )

    words = text.split(" ")
    # Position of the final word of the text actually being annotated, so the
    # padding is right for any text, not only the current input box content.
    last_index = len(words) - 1

    result = []
    for i, word in enumerate(words):
        if re.sub(r"[^\w\s]", "", word) in keyphrases:
            # Single-word keyphrase, compared with punctuation stripped.
            result.append((word, "KEY", "#21c354"))
        elif "$K" in word:
            # Multi-word keyphrase: restore the spaces replaced above.
            result.append((" ".join(word.split("$K")), "KEY", "#21c354"))
        elif i == last_index:
            result.append(f" {word}")
        elif i == 0:
            result.append(f"{word} ")
        else:
            result.append(f" {word} ")
    return result
def rerender_output(layout):
    """Render the annotated text into *layout* under an output subheader.

    The latest extraction (``input_text`` and ``keyphrases`` from the session
    state) is shown when keyphrases exist and no history row is selected;
    otherwise the first selected history row is shown.

    Args:
        layout: Streamlit container that receives the subheader and the
            annotated HTML.
    """
    layout.subheader("π§ Output")

    state = st.session_state
    show_latest = len(state.keyphrases) > 0 and len(state.selected_rows) == 0
    if show_latest:
        text = state.input_text
        keyphrases = state.keyphrases
    else:
        selection = state.selected_rows
        text = selection["text"].values[0]
        # Every column except "model" and "text" holds a keyphrase; rows with
        # fewer keyphrases carry "" in the remaining cells, which are dropped.
        keyphrase_columns = selection.columns.difference(["model", "text"])
        first_row = (
            selection.loc[:, keyphrase_columns].astype(str).values.tolist()[0]
        )
        keyphrases = [cell for cell in first_row if cell != ""]

    annotated = get_annotated_text(text, keyphrases)
    layout.markdown(get_annotated_html(*annotated), unsafe_allow_html=True)
# --- Input panel (first column): model choice, text box, extract button ---
app_config = st.session_state.config
chosen_model = col1.selectbox("Choose your model:", app_config.get("models"))
st.session_state.chosen_model = chosen_model

# The full model name is "<model_author>/<chosen model>", both from the config.
model_id = f"{app_config.get('model_author')}/{chosen_model}"
pipe = load_pipeline(model_id)

st.session_state.input_text = col1.text_area(
    "Input",
    app_config.get("example_text"),
    height=300,
)
pressed = col1.button("Extract", on_click=extract_keyphrases)

# --- History grid: one selectable row per earlier extraction ---
history = st.session_state.data_frame
if len(history.columns) > 0:
    st.subheader("π History")
    builder = GridOptionsBuilder.from_dataframe(history, sortable=False)
    builder.configure_selection(selection_mode="single", use_checkbox=True)
    # The full input text stays in the data but is hidden from the grid.
    builder.configure_column("text", hide=True)
    go = builder.build()
    data = AgGrid(
        history,
        gridOptions=go,
        update_mode=GridUpdateMode.SELECTION_CHANGED,
    )
    st.session_state.selected_rows = pd.DataFrame(data["selected_rows"])

# --- Output panel (second column): shown once there is something to render ---
if len(st.session_state.selected_rows) > 0 or len(st.session_state.keyphrases) > 0:
    rerender_output(col2)