# Hugging Face Space: multilingual topic modeling demo (de/fr/it)
| import argparse | |
| import json as js | |
| import os | |
| import re | |
| from pathlib import Path | |
| from typing import List, Tuple | |
| import fasttext | |
| import gradio as gr | |
| import joblib | |
| import omikuji | |
| from huggingface_hub import snapshot_download | |
| from prepare_everything import download_model | |
# Fetch the fastText language-identification model (176 languages) to disk.
download_model(
    "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin",
    Path("lid.176.bin"))
# Download the model files from Hugging Face
model_names = [
    "omikuji-bonsai-parliament-spacy-de-all_topics-input_long",
    "omikuji-bonsai-parliament-spacy-fr-all_topics-input_long",
    "omikuji-bonsai-parliament-spacy-it-all_topics-input_long",
]
for repo_id in model_names:
    # NOTE(review): this check/mkdir targets "<repo_id>", but the snapshot is
    # downloaded into "kapllan/<repo_id>" (the path find_model later reads),
    # so the directory created here appears unused — confirm intent.
    if not os.path.exists(repo_id):
        os.makedirs(repo_id)
    model_dir = snapshot_download(repo_id=f"kapllan/{repo_id}", local_dir=f"kapllan/{repo_id}")
# Language-identification model used by predict_lang().
lang_model = fasttext.load_model("lid.176.bin")
with open(Path("label2id.json"), "r") as f:
    label2id = js.load(f)
# Invert label2id so numeric prediction ids map back to label strings.
id2label = {}
for key, value in label2id.items():
    id2label[str(value)] = key
# Mapping of main topic -> list of sub-topics, used by restructure_topics().
with open(Path("topics_hierarchy.json"), "r") as f:
    topics_hierarchy = js.load(f)
def map_language(language: str) -> str:
    """Translate an ISO code ("de"/"it"/"fr") to its English name.

    Unknown codes are returned unchanged.
    """
    return {"de": "German", "it": "Italian", "fr": "French"}.get(language, language)
def find_model(language: str):
    """Load the vectorizer and Omikuji model for a supported language.

    Returns (None, None) when no model exists for *language*; otherwise
    the (vectorizer, omikuji model) pair for de/fr/it.
    """
    if language not in ("de", "fr", "it"):
        return None, None
    base = f"./kapllan/omikuji-bonsai-parliament-spacy-{language}-all_topics-input_long"
    vectorizer = joblib.load(f"{base}/vectorizer")
    model = omikuji.Model.load(f"{base}/omikuji-model")
    return vectorizer, model
def predict_lang(text: str) -> str:
    """Detect the language of *text* with fastText, returning an ISO code."""
    # fastText's predict() rejects strings containing newlines, so drop them.
    cleaned = text.replace("\n", "")
    labels, _scores = lang_model.predict(cleaned, k=1)
    # Labels come back as "__label__de"; strip the prefix to get the code.
    return labels[0].replace("__label__", "")
def predict_topic(text: str) -> Tuple[List[Tuple[str, float]], str]:
    """Detect the text's language and predict topics with the matching model.

    Returns:
        A list of (topic_label, score) pairs — empty when no model exists for
        the detected language or the text vectorizes to all zeros — together
        with the human-readable language name.
    """
    # Fix: the original annotation `[List[str], str]` was a list literal, not
    # a valid type expression, and mis-stated the element type.
    results = []
    language = predict_lang(text)
    vectorizer, model = find_model(language)
    language = map_language(language)
    if vectorizer is not None:
        texts = [text]
        vector = vectorizer.transform(texts)
        for row in vector:
            if row.nnz == 0:  # all-zero vector: nothing to predict on
                continue
            feature_values = [(col, row[0, col]) for col in row.nonzero()[1]]
            for subj_id, score in model.predict(feature_values, top_k=1000):
                score = round(score, 2)
                results.append((id2label[str(subj_id)], score))
    return results, language
def get_row_color(type: str) -> str:
    """Return an inline CSS background for a table row by topic type.

    Main topics get dark grey, sub topics light grey.  Any other type now
    yields "" — the original fell through and returned None, which rendered
    as the literal style string "None" in the generated HTML.
    """
    label = type.lower()
    if "main" in label:
        return "background-color: darkgrey;"
    if "sub" in label:
        return "background-color: lightgrey;"
    return ""
def generate_html_table(topics: List[Tuple[str, str, float]]) -> str:
    """Render (type, topic, score) triples as a single HTML table string."""
    parts = [
        '<table style="width:100%; border: 1px solid black; border-collapse: collapse;">',
        "<tr><th>Type</th><th>Topic</th><th>Score</th></tr>",
    ]
    for type, topic, score in topics:
        color = get_row_color(type)
        if "main" in type.lower():
            # Emphasise main-topic rows with bold cells.
            topic = f"<strong>{topic}</strong>"
            type = f"<strong>{type}</strong>"
            score = f"<strong>{score}</strong>"
        parts.append(
            f'<tr style="{color}"><td>{type}</td><td>{topic}</td><td>{score}</td></tr>'
        )
    parts.append("</table>")
    return "".join(parts)
def restructure_topics(topics: List[Tuple[str, float]]) -> List[Tuple[str, str, float]]:
    """Group flat (topic, score) predictions into main/sub-topic table rows.

    A main topic is emitted only when at least one of its sub-topics was also
    predicted; each main-topic row is followed by its de-duplicated sub-topic
    rows.  Topic names are lower-cased and the "hauptthema: "/"unterthema: "
    prefixes are stripped for display.
    """
    normalized = [(str(name).lower(), score) for name, score in topics]
    # First occurrence wins, matching the original first-match score lookup.
    first_score = {}
    for name, score in normalized:
        first_score.setdefault(name, score)
    # Predicted main topics, in order of first appearance.
    grouped = {}
    for name, _ in normalized:
        if name in topics_hierarchy:
            grouped[name] = []
    # Attach every predicted sub-topic to its predicted main topic.
    for name, _ in normalized:
        for main_topic, sub_topics in topics_hierarchy.items():
            if main_topic in grouped and name != main_topic and name in sub_topics:
                grouped[main_topic].append(name)
    rows = []
    for main_topic, sub_topics in grouped.items():
        if not sub_topics:
            continue
        rows.append(
            ("Main Topic", main_topic.replace("hauptthema: ", ""), first_score[main_topic])
        )
        deduped = []
        for sub in sub_topics:
            entry = ("Sub Topic", sub.replace("unterthema: ", ""), first_score[sub])
            if entry not in deduped:
                deduped.append(entry)
        rows.extend(deduped)
    return rows
def topic_modeling(text: str, threshold: float) -> Tuple[str, str]:
    """Predict topics for *text*, filter by score, and render as an HTML table.

    Fix: the original annotation `[List[str], str]` was a list literal, not a
    valid type expression; the function actually returns (html, language).

    Args:
        text: Input document.
        threshold: Minimum prediction score to keep a topic.

    Returns:
        (html_table, detected_language_name).
    """
    sorted_topics, language = predict_topic(text)
    # Only keep results for the three supported languages.
    if sorted_topics and language in ["German", "French", "Italian"]:
        sorted_topics = [t for t in sorted_topics if t[1] >= threshold]
    else:
        sorted_topics = []
    sorted_topics = restructure_topics(sorted_topics)
    return generate_html_table(sorted_topics), language
# Build the Gradio UI: text + threshold inputs on the left, HTML output on
# the right; submitting runs topic_modeling and fills both output widgets.
with gr.Blocks() as iface:
    gr.Markdown("# Topic Modeling")
    gr.Markdown("Enter a document and get each topic along with its score.")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=10, placeholder="Enter a document")
            submit_button = gr.Button("Submit")
            # Minimum score a topic must reach to be displayed.
            threshold_slider = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, label="Score Threshold", value=0.0
            )
            # Read-only display of the fastText-detected language.
            language_text = gr.Textbox(
                lines=1,
                placeholder="Detected language will be shown here...",
                interactive=False,
                label="Detected Language",
            )
        with gr.Column():
            output_data = gr.HTML()
    submit_button.click(
        topic_modeling,
        inputs=[input_text, threshold_slider],
        outputs=[output_data, language_text],
    )
if __name__ == "__main__":
    # Optional host IP; without it the app is shared via a public gradio link.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-ipa",
        "--ip_address",
        default=None,
        type=str,
        help="Specify the IP address of your computer.",
    )
    options = cli.parse_args()
    if options.ip_address is not None:
        # Serve on the given host at port 8080.
        iface.launch(server_name=options.ip_address, server_port=8080, show_error=True)
    else:
        _, public_url = iface.launch(share=True)
        print(f"The app runs here: {public_url}")