import os import time import gradio as gr import pandas as pd from classifier import classify from statistics import mean HFTOKEN = os.environ["HF_TOKEN"] def load_and_analyze_csv(file, text_field, event_model): df = pd.read_table(file.name) if text_field not in df.columns: raise gr.Error(f"Error: Enter text column'{text_field}' not in CSV file.") floods, fires, nones, scores = [], [], [], [] for post in df[text_field].to_list(): res = classify(post, event_model, HFTOKEN) if res["event"] == 'flood': floods.append(post) elif res["event"] == 'fire': fires.append(post) else: nones.append(post) scores.append(res["score"]) model_confidence = round(mean(scores), 5) fire_related = gr.CheckboxGroup(choices=fires) flood_related = gr.CheckboxGroup(choices=floods) not_related = gr.CheckboxGroup(choices=nones) return flood_related, fire_related, not_related, model_confidence, len(df[text_field].to_list()) def analyze_selected_texts(selections): selected_texts = selections analysis_results = [f"Word Count: {len(text.split())}" for text in selected_texts] result_df = pd.DataFrame({"Selected Text": selected_texts, "Analysis": analysis_results}) return result_df def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts): incorrect = len(flood_selections) + len(fire_selections) + len(none_selections) correct = num_posts - incorrect accuracy = (correct/num_posts)*100 return incorrect, correct, accuracy with gr.Blocks() as demo: event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier"] with gr.Tab("Event Type Classification"): gr.Markdown( """ # T4.5 Relevance Classifier Demo This is a demo created to explore floods and wildfire classification in social media posts.\n Usage:\n -Upload .tsv data file (must contain a text column with social media posts).\n -Next, type the name of the text column.\n -Then, choose a BERT classifier model from the drop down.\n -Finally, click the 'start prediction' buttton.\n Evaluation:\n -To evaluate the model's accuracy select the INCORRECT classifications using the checkboxes in front of each post.\n -Then, click on the 'Calculate Accuracy' button.\n -Then, click on the 'Download data as CSV' to get the classifications and evaluation data as a .csv file. """) with gr.Row(equal_height=True): with gr.Column(scale=4): file_input = gr.File(label="Upload CSV File") with gr.Column(scale=6): text_field = gr.Textbox(label="Text field name", value="tweet_text") event_model = gr.Dropdown(event_models, label="Select classification model") predict_button = gr.Button("Start Prediction") with gr.Row(): # XXX confirm this is not a problem later --equal_height=True with gr.Column(): gr.Markdown("""### Flood-related""") flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True) with gr.Column(): gr.Markdown("""### Fire-related""") fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True) with gr.Column(): gr.Markdown("""### None""") none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True) with gr.Row(equal_height=True): with gr.Column(scale=5): gr.Markdown(r""" Accuracy: is the model's ability to make correct predicitons. It is the fraction of correct prediction out of the total predictions. $$ \text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} * 100 $$ Model Confidence: is the mean probabilty of each case belonging to their assigned classes. A value of 1 is best. """, latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }]) gr.Markdown("\n\n\n") model_confidence = gr.Number(label="Model Confidence") with gr.Column(scale=5): correct = gr.Number(label="Number of correct classifications", value=0) incorrect = gr.Number(label="Number of incorrect classifications", value=0) accuracy = gr.Number(label="Model Accuracy", value=0) accuracy_button = gr.Button("Calculate Accuracy") num_posts = gr.Number(visible=False) predict_button.click(load_and_analyze_csv, inputs=[file_input, text_field, event_model], outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence, num_posts]) accuracy_button.click(calculate_accuracy, inputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, num_posts], outputs=[incorrect, correct, accuracy]) with gr.Tab("Question Answering"): # XXX Add some button disabling here, if the classification process is not completed first XXX analysis_button = gr.Button("Analyze Selected Texts") analysis_output = gr.DataFrame(headers=["Selected Text", "Analysis"]) analysis_button.click(analyze_selected_texts, inputs=flood_checkbox_output, outputs=analysis_output) demo.launch()