Spaces:
Runtime error
import json
from collections import Counter, defaultdict

import gradio as gr
import pandas as pd
from matplotlib.figure import Figure
# Biomedical NER model: load the tokenizer and token-classification model once,
# at import time, and wrap them in a "ner" pipeline. aggregation_strategy="simple"
# makes the pipeline return grouped entities (items carrying "entity_group").
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification

_MODEL_ID = "d4data/biomedical-ner-all"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForTokenClassification.from_pretrained(_MODEL_ID)
pipe = pipeline(
    task="ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)
# Matplotlib is used for the entity-count bar chart. Select the non-interactive
# Agg backend so figures render to images without needing a display.
import matplotlib.pyplot as plt
plt.switch_backend(newbackend="Agg")
# Load the demo examples from examples.json into a mapping of example text ->
# label. This assumes the file holds a list of objects with "text" and "label"
# keys; if a text appears twice, the last occurrence wins.
# The file is decoded explicitly as UTF-8 so that non-ASCII note text does not
# depend on the platform's default locale encoding.
with open("examples.json", "r", encoding="utf-8") as f:
    example_json = json.load(f)
EXAMPLES = {x["text"]: x["label"] for x in example_json}
def group_by_entity(raw):
    """Count how many entities were found in each entity group.

    Args:
        raw: Iterable of NER pipeline results; every item must have an
            ``"entity_group"`` key.

    Returns:
        A ``collections.Counter`` (a ``dict`` subclass) that maps each
        entity-group name to its number of occurrences, in first-seen order.
        Empty input yields an empty mapping.
    """
    return Counter(ent["entity_group"] for ent in raw)
def plot_to_figure(grouped):
    """Draw a bar chart with one bar per entity group.

    Args:
        grouped: Mapping of entity-group name -> count.

    Returns:
        A ``matplotlib.figure.Figure``. An empty mapping gives an empty chart.
    """
    # Build the Figure directly rather than through pyplot, for two reasons:
    # - pyplot keeps every figure it creates alive until it is closed
    #   explicitly, which leaks memory in a long-running server;
    # - pyplot's implicit "current figure/axes" is shared global state,
    #   which is unsafe if requests are handled on multiple threads.
    fig = Figure()
    ax = fig.subplots()
    ax.bar(list(grouped.keys()), list(grouped.values()))
    ax.margins(0.2)
    # Leave room at the bottom for the vertically rotated group names.
    fig.subplots_adjust(bottom=0.4)
    ax.tick_params(axis="x", labelrotation=90)
    return fig
def ner(text):
    """Run biomedical NER on *text* and build the four Gradio outputs.

    Args:
        text: Note text entered by the user or chosen from the examples.

    Returns:
        A 4-tuple ``(ner_content, meta, label, figure)``:

        - ner_content: ``{"text": text, "entities": [...]}``, the value passed
          to the highlighted-text output. Each entity carries its group name,
          word, score and start/end offsets as returned by the pipeline.
        - meta: per-group entity counts, the number of distinct groups, and
          the total number of entities.
        - label: the label stored for *text* in ``EXAMPLES``, or ``"Unknown"``
          when *text* is not one of the examples.
        - figure: bar chart of the per-group counts.
    """
    raw = pipe(text)
    ner_content = {
        "text": text,
        "entities": [
            {
                "entity": x["entity_group"],
                "word": x["word"],
                # The pipeline may return NumPy scalars (e.g. float32), which
                # the stdlib json encoder rejects; cast to a plain float.
                "score": float(x["score"]),
                "start": x["start"],
                "end": x["end"],
            }
            for x in raw
        ],
    }
    grouped = group_by_entity(raw)
    figure = plot_to_figure(grouped)
    label = EXAMPLES.get(text, "Unknown")
    meta = {
        "entity_counts": grouped,
        # Mapping keys are already unique, so no set() is needed.
        "entities": len(grouped),
        "counts": sum(grouped.values()),
    }
    return (ner_content, meta, label, figure)
# Gradio UI: one note-text input and four outputs, in the same order as the
# tuple returned by ner(). The example texts come from EXAMPLES; the app is
# launched when this module runs.
note_input = gr.Textbox(label="Note text", value="")
output_components = [
    gr.HighlightedText(label="NER", combine_adjacent=True),
    gr.JSON(label="Entity Counts"),
    gr.Label(label="Rating"),
    gr.Plot(label="Bar"),
]
interface = gr.Interface(
    fn=ner,
    inputs=note_input,
    outputs=output_components,
    examples=list(EXAMPLES),
    allow_flagging="never",
)
interface.launch()