import os
import tempfile
import time
from statistics import mean

import gradio as gr
import pandas as pd

from classifier import classify

# Hugging Face access token forwarded to `classify`. Reading it with [] makes
# the app fail fast at startup (KeyError) if HF_TOKEN is not set.
HFTOKEN = os.environ["HF_TOKEN"]


def _read_uploaded_table(file):
    """Read an uploaded .csv or .tsv file into a DataFrame.

    Args:
        file: The value delivered by ``gr.File`` -- an object exposing the
            temp-file path as ``.name``, or a plain path string.

    Returns:
        pandas.DataFrame with the file contents.

    Raises:
        gr.Error: if nothing was uploaded, the extension is unsupported, or
            the file is empty.
    """
    if file is None:
        raise gr.Error("Error: please upload a .csv or .tsv file first.")
    path = str(getattr(file, "name", file))
    suffix = path.lower()
    try:
        # Match on the suffix, not a substring, so e.g. "x.csv.tsv" is a TSV.
        if suffix.endswith(".csv"):
            return pd.read_csv(path)
        if suffix.endswith(".tsv"):
            return pd.read_table(path)
    except pd.errors.EmptyDataError as err:
        raise gr.Error("Error: the uploaded file is empty.") from err
    raise gr.Error(
        f"Error: unsupported file '{os.path.basename(path)}'; expected a .csv or .tsv file."
    )


def _posts_with_label(df, text_field, label):
    """Return the texts in ``df[text_field]`` whose ``model_label`` equals ``label``."""
    return df.loc[df["model_label"] == label, text_field].to_list()


def load_and_classify_csv(file, text_field, event_model):
    """Classify every post in an uploaded file and build the review checkboxes.

    Args:
        file: Uploaded .csv/.tsv file from ``gr.File``.
        text_field: Name of the column holding the post text.
        event_model: Model identifier passed through to ``classify``.

    Returns:
        Tuple of (flood CheckboxGroup, fire CheckboxGroup, none CheckboxGroup,
        mean model score rounded to 5 places, number of posts, DataFrame with
        added ``model_label`` and ``model_score`` columns).

    Raises:
        gr.Error: on a missing/unsupported/empty file, a missing text column,
            or no selected model.
    """
    df = _read_uploaded_table(file)
    if text_field not in df.columns:
        raise gr.Error(f"Error: text column '{text_field}' not found in the uploaded file.")
    if df.empty:
        # mean() below would raise StatisticsError on an empty score list.
        raise gr.Error("Error: the uploaded file contains no rows to classify.")
    if not event_model:
        raise gr.Error("Error: please select a classification model.")

    labels, scores = [], []
    for post in df[text_field].to_list():
        # NOTE(review): `classify` is assumed to return a dict with "event"
        # (one of "flood"/"fire"/"none") and "score" keys -- confirm in classifier.py.
        res = classify(post, event_model, HFTOKEN)
        labels.append(res["event"])
        scores.append(res["score"])
    df["model_label"] = labels
    df["model_score"] = scores

    model_confidence = round(mean(scores), 5)
    flood_related = gr.CheckboxGroup(choices=_posts_with_label(df, text_field, "flood"))
    fire_related = gr.CheckboxGroup(choices=_posts_with_label(df, text_field, "fire"))
    not_related = gr.CheckboxGroup(choices=_posts_with_label(df, text_field, "none"))

    return flood_related, fire_related, not_related, model_confidence, len(df), df


def qa_process(selections):
    """Build a placeholder analysis table for the selected queries.

    Currently reports only the word count of each selected query.

    Args:
        selections: List of query strings (``None`` is treated as empty).

    Returns:
        DataFrame with "Selected Text" and "Analysis" columns.
    """
    selected_texts = list(selections or [])
    analysis_results = [f"Word Count: {len(text.split())}" for text in selected_texts]
    return pd.DataFrame({"Selected Text": selected_texts, "Analysis": analysis_results})


def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
    """Score the model against the user's "incorrect" selections and export a CSV.

    Args:
        flood_selections, fire_selections, none_selections: Post texts the user
            marked as misclassified in each group (``None`` treated as empty).
        num_posts: Post count from the prediction step. Kept for signature
            compatibility; the row count of ``data_df`` is used instead so the
            totals always agree with the exported ``model_eval`` column.
        text_field: Name of the text column in ``data_df``.
        data_df: DataFrame produced by ``load_and_classify_csv``.

    Returns:
        Tuple of (incorrect count, correct count, accuracy in percent,
        DataFrame with a ``model_eval`` column, visible DownloadButton for the CSV).

    Raises:
        gr.Error: if no prediction results are available yet.
    """
    if data_df is None or data_df.empty or text_field not in data_df.columns:
        raise gr.Error("Error: run 'Start Prediction' before calculating accuracy.")

    # A set makes each membership test O(1) instead of scanning the selections.
    selected = set(flood_selections or []) | set(fire_selections or []) | set(none_selections or [])
    posts = data_df[text_field].to_list()
    evaluation = ["incorrect" if post in selected else "correct" for post in posts]
    data_df["model_eval"] = evaluation

    total = len(posts)
    # Count flagged rows (not selections) so duplicated posts are counted the
    # same way they are marked in the exported model_eval column.
    incorrect = evaluation.count("incorrect")
    correct = total - incorrect
    accuracy = (correct / total) * 100

    # A fresh temp dir per call keeps concurrent sessions from overwriting
    # each other's export while preserving the "output.csv" download name.
    out_path = os.path.join(tempfile.mkdtemp(), "output.csv")
    data_df.to_csv(out_path)
    return incorrect, correct, accuracy, data_df, gr.DownloadButton(label="Download CSV", value=out_path, visible=True)


def get_queries():
    """Return a CheckboxGroup listing the predefined QA queries."""
    queries = [
        "What areas are being evacuated?",
        "What areas are predicted to be impacted?",
        "What areas are without power?",
        "What barriers are hindering response efforts?",
        "What events have been canceled?",
        "What preparations are being made?",
        "What regions have announced a state of emergency?",
        "What roads are blocked / closed?",
        "What services have been closed?",
        "What warnings are currently in effect?",
        "Where are emergency services deployed?",
        "Where are emergency services needed?",
        "Where are evacuations needed?",
        "Where are people needing rescued?",
        "Where are recovery efforts taking place?",
        "Where has building or infrastructure damage occurred?",
        # The missing comma here used to glue this query to the next one.
        "Where has flooding occurred?",
        "Where are volunteers being requested?",
        "Where has road damage occurred?",
        "What area has the wildfire burned?",
        "Where have homes been damaged or destroyed?",
    ]
    return gr.CheckboxGroup(choices=queries)


with gr.Blocks() as demo:
    event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier"]

    with gr.Tab("Event Type Classification"):
        gr.Markdown(
            """
            # T4.5 Relevance Classifier Demo
            This is a demo created to explore floods and wildfire classification in social media posts.\n
            Usage:\n
            (1.) Upload a .csv or .tsv data file (must contain a text column with social media posts).\n
            -Next, type the name of the text column.\n
            -Then, choose a BERT classifier model from the drop down.\n
            -Finally, click the 'Start Prediction' button.\n
            Evaluation:\n
            -To evaluate the model's accuracy select the INCORRECT classifications using the checkboxes in front of each post.\n
            -Then, click on the 'Calculate Accuracy' button.\n
            -Then, click on the 'Download CSV' button to get the classifications and evaluation data as a .csv file.
            """)
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])
            with gr.Column(scale=6):
                text_field = gr.Textbox(label="Text field name", value="tweet_text")
                # Default to the first model so classify() never receives None.
                event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
        predict_button = gr.Button("Start Prediction")

        with gr.Row():  # XXX confirm this is not a problem later --equal_height=True
            with gr.Column():
                gr.Markdown("""### Flood-related""")
                flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
            with gr.Column():
                gr.Markdown("""### Fire-related""")
                fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
            with gr.Column():
                gr.Markdown("""### None""")
                none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                gr.Markdown(
                    r"""
                    Accuracy: is the model's ability to make correct predictions.
                    It is the percentage of correct predictions out of the total predictions.

                    $$
                    \text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} * 100
                    $$

                    Model Confidence: is the mean probability of each case
                    belonging to their assigned classes. A value of 1 is best.
                    """,
                    latex_delimiters=[{"left": "$$", "right": "$$", "display": True}])
                gr.Markdown("\n\n\n")
                model_confidence = gr.Number(label="Model Confidence")
            with gr.Column(scale=5):
                correct = gr.Number(label="Number of correct classifications")
                incorrect = gr.Number(label="Number of incorrect classifications")
                accuracy = gr.Number(label="Model Accuracy (%)")

        accuracy_button = gr.Button("Calculate Accuracy")
        download_csv = gr.DownloadButton(visible=False)

        # Hidden components that carry state between the two button callbacks.
        num_posts = gr.Number(visible=False)
        data = gr.DataFrame(visible=False)
        data_eval = gr.DataFrame(visible=False)

        predict_button.click(
            load_and_classify_csv,
            inputs=[file_input, text_field, event_model],
            outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence, num_posts, data])
        accuracy_button.click(
            calculate_accuracy,
            inputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, num_posts, text_field, data],
            outputs=[incorrect, correct, accuracy, data_eval, download_csv])

    qa_tab = gr.Tab("Question Answering")
    with qa_tab:
        # XXX Add some button disabling here, if the classification process is not completed first XXX
        selected_queries = gr.CheckboxGroup(label="Select at least one query using the checkboxes", interactive=True)
        # The query list is populated lazily when the tab is opened.
        qa_tab.select(get_queries, None, selected_queries)
        qa_button = gr.Button("Start QA")
        analysis_output = gr.DataFrame(headers=["Selected Text", "Analysis"])
        qa_button.click(qa_process, inputs=selected_queries, outputs=analysis_output)

    # analysis_button = gr.Button("Analyze Selected Texts")
    # analysis_output = gr.DataFrame(headers=["Selected Text", "Analysis"])
    # analysis_button.click(analyze_selected_texts, inputs=flood_checkbox_output, outputs=analysis_output)

demo.launch()