Spaces:
Running
Running
| import torch | |
| import pandas as pd | |
| import configparser | |
| import gradio as gr | |
| from gensim.models import KeyedVectors | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoModelForTokenClassification, AutoTokenizer | |
| from segmentation import segment | |
| from utils import clean_entity | |
| class Linker: | |
| def __init__(self, config: dict[str, object], | |
| context_window_width: int = -1): | |
| self._vectors = None | |
| self._emb_model = None | |
| if context_window_width <= 0: | |
| context_window_width = config['context_window_width'] | |
| self.context_window_width = context_window_width | |
| self.config = config | |
| def add_context(self, row: pd.Series) -> str: | |
| window_start = max(0, row.start - self.context_window_width) | |
| window_end = min(row.end + self.context_window_width, len(row.text)) | |
| return clean_entity(row.text[window_start:window_end]) | |
| def _load_embeddings(self): | |
| self._vectors = KeyedVectors.load(self.config['keyed_vectors_file']) | |
| def _load_model(self): | |
| self._emb_model = SentenceTransformer(config['embedding_model']) | |
| def embeddings(self): | |
| if self._vectors is None: | |
| self._load_embeddings() | |
| return self._vectors | |
| def embedding_model(self): | |
| if self._emb_model is None: | |
| self._load_model() | |
| return self._emb_model | |
| def link(self, df: pd.DataFrame) -> list[dict]: | |
| mention_emb = self.embedding_model.encode(df.mention.str.lower().values) | |
| concepts = [self.embeddings.most_similar(m, topn=1)[0][0] | |
| for m in mention_emb] | |
| return concepts | |
| def highlight_text(spans: pd.DataFrame, text: str) -> list[tuple[str, object]]: | |
| token_concepts = [None for _ in text] | |
| for row in spans.itertuples(): | |
| for k in range(row.start, row.end): | |
| token_concepts[k] = row.concept | |
| return list(zip(list(text), token_concepts)) | |
| def entity_link(query: str) -> list[tuple[str, object]]: | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| seg_model = AutoModelForTokenClassification.from_pretrained( | |
| config['segmentation_model'] | |
| ) | |
| seg_tokenizer = AutoTokenizer.from_pretrained( | |
| config['segmentation_tokenizer'] | |
| ) | |
| thresh = float(config['thresh']) | |
| query_df = pd.DataFrame({'note_id': [0], 'text': [query]}) | |
| seg = segment(query_df, seg_model, seg_tokenizer, device, thresh) | |
| linked_concepts = [] | |
| if len(seg) > 0: | |
| seg = seg.sort_values('start') | |
| linked_concepts = linker.link(seg) | |
| seg['concept'] = linked_concepts | |
| return highlight_text(seg, query) | |
| config_parser = configparser.ConfigParser() | |
| config_parser.read('config.ini') | |
| config = config_parser['DEFAULT'] | |
| linker = Linker(config) | |
| demo = gr.Interface( | |
| fn=entity_link, | |
| inputs=["text"], | |
| outputs=gr.HighlightedText( | |
| label="linking", | |
| combine_adjacent=True, | |
| ), | |
| theme=gr.themes.Base() | |
| ) | |
| demo.launch() | |