Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import time | |
| import string | |
| import random | |
| import gradio as gr | |
| import pandas as pd | |
| from classifier import classify | |
| from statistics import mean | |
| from qa_summary import generate_answer | |
# Hugging Face access token, forwarded to every `classify` call below.
# Fail fast with an actionable message instead of a bare KeyError.
try:
    HFTOKEN = os.environ["HF_TOKEN"]
except KeyError as err:
    raise RuntimeError(
        "The HF_TOKEN environment variable must be set before starting the app."
    ) from err

# JavaScript passed as `js=` to `demo.load`: injects Twitter's widgets.js into
# the page and exposes a global `reloadTwitterWidgets()` helper that re-runs
# `twttr.widgets.load()` once the script is available.
loadTwitterWidgets_js = """
async () => {
    // Load Twitter Widgets script
    const script = document.createElement("script");
    script.onload = () => console.log("Twitter Widgets.js loaded");
    script.src = "https://platform.twitter.com/widgets.js";
    document.head.appendChild(script);

    // Define a global function to reload Twitter widgets
    globalThis.reloadTwitterWidgets = () => {
        // Reload Twitter widgets
        if (window.twttr && twttr.widgets) {
            twttr.widgets.load();
        }
    };
}
"""
def T_on_select(evt: gr.SelectData):
    """Return the value of the dataframe cell that was just selected.

    Wired to `.select` listeners with no explicit inputs, so Gradio supplies
    the selection event as ``evt``.
    """
    selected_value = evt.value
    return selected_value
def single_classification(text, event_model, threshold):
    """Classify a single text with the selected model.

    Returns:
        A ``(label, score)`` pair read from the ``"event"`` and ``"score"``
        entries of the ``classify`` result.
    """
    prediction = classify(text, event_model, HFTOKEN, threshold)
    label, score = prediction["event"], prediction["score"]
    return label, score
def load_and_classify_csv(file, text_field, event_model, threshold):
    """Classify every post of an uploaded CSV/TSV file, grouped for manual review.

    Args:
        file: Uploaded file object; only its ``.name`` (path on disk) is used.
        text_field: Column holding the posts; surrounding whitespace is ignored.
        event_model: Model identifier forwarded to ``classify``.
        threshold: Prediction threshold forwarded to ``classify``.

    Returns:
        A tuple of the flood, fire and "none" checkbox groups (one choice per
        post predicted with that label), the mean model score, the number of
        posts, the annotated dataframe, and two updates enabling the QA buttons.

    Raises:
        gr.Error: If ``text_field`` is not a column of the file.
    """
    text_field = text_field.strip()
    filepath = file.name
    # Choose the parser from the extension (case-insensitive); anything that is
    # not .csv is read as tab-separated.
    if filepath.lower().endswith(".csv"):
        df = pd.read_csv(filepath)
    else:
        df = pd.read_table(filepath)
    if text_field not in df.columns:
        raise gr.Error(f"Error: Enter text column'{text_field}' not in CSV file.")

    labels, scores = [], []
    for post in df[text_field].to_list():
        res = classify(post, event_model, HFTOKEN, threshold)
        labels.append(res["event"])
        scores.append(res["score"])
    df["event_label"] = labels
    df["model_score"] = scores

    # statistics.mean raises on an empty sequence; report 0 for an empty file.
    model_confidence = mean(scores) if scores else 0.0

    def posts_labelled(label):
        # Posts whose predicted label (the "event_label" column written above)
        # equals `label`.
        return df.loc[df["event_label"] == label, text_field].to_list()

    flood_related = gr.CheckboxGroup(choices=posts_labelled("flood"))
    fire_related = gr.CheckboxGroup(choices=posts_labelled("fire"))
    not_related = gr.CheckboxGroup(choices=posts_labelled("none"))
    return (flood_related, fire_related, not_related, model_confidence, len(df), df,
            gr.update(interactive=True), gr.update(interactive=True))
def load_and_classify_csv_dataframe(file, text_field, event_model, threshold):
    """Classify every post of an uploaded CSV/TSV file for the dataframe view.

    Args:
        file: Uploaded file object; only its ``.name`` (path on disk) is used.
        text_field: Column holding the posts; surrounding whitespace is ignored.
        event_model: Model identifier forwarded to ``classify``.
        threshold: Prediction threshold forwarded to ``classify``.

    Returns:
        The annotated dataframe twice (once for display, once for state), an
        update populating the label-filter dropdown, and two updates enabling
        the QA buttons. Added columns: ``event_label``, ``model_score``
        (rounded to 5 decimals) and ``tweet_id`` (as str).

    Raises:
        gr.Error: If ``text_field`` is not a column of the file.
    """
    text_field = text_field.strip()
    filepath = file.name
    # Choose the parser from the extension (case-insensitive); anything that is
    # not .csv is read as tab-separated.
    if filepath.lower().endswith(".csv"):
        df = pd.read_csv(filepath)
    else:
        df = pd.read_table(filepath)
    if text_field not in df.columns:
        raise gr.Error(f"Error: Enter text column'{text_field}' not in CSV file.")

    labels, scores = [], []
    for post in df[text_field].to_list():
        res = classify(post, event_model, HFTOKEN, threshold)
        labels.append(res["event"])
        scores.append(round(res["score"], 5))
    df["event_label"] = labels
    df["model_score"] = scores

    if "tweet_id" not in df.columns:
        # Placeholder random 10-digit IDs so every row has an ID cell; they are
        # not derived from the data, so they should not be expected to embed.
        df["tweet_id"] = ["".join(random.choices(string.digits, k=10)) for _ in range(len(df))]
    # NOTE(review): if an existing tweet_id column has blanks, pandas reads it as
    # float and long IDs lose precision; consider reading it with dtype=str.

    result_df = df.copy()
    result_df["tweet_id"] = result_df["tweet_id"].astype(str)

    # Filter choices: each predicted label, its "Not-<label>" negation, and "All".
    filters = list(result_df["event_label"].unique())
    extra_filters = ["Not-" + x for x in filters] + ["All"]
    filter_update = gr.update(choices=sorted(filters + extra_filters),
                              value="All",
                              label="Filter data by label",
                              visible=True)
    return result_df, result_df, filter_update, gr.update(interactive=True), gr.update(interactive=True)
def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
    """Score the model from the user's manual corrections and export the result.

    Every post ticked in any of the three checkbox groups counts as an incorrect
    prediction; all remaining posts count as correct.

    Args:
        flood_selections, fire_selections, none_selections: Lists of post texts
            the user marked as wrongly classified.
        num_posts: Total number of classified posts.
        text_field: Column holding the posts; surrounding whitespace is ignored.
        data_df: Classified dataframe; a ``model_eval`` column is added in place.

    Returns:
        ``(incorrect, correct, accuracy_percent, data_df, download_button)``;
        the dataframe is also written to ``output.csv`` for download.
    """
    text_field = text_field.strip()
    selections = flood_selections + fire_selections + none_selections
    # Set membership keeps marking linear in the number of posts.
    flagged = set(selections)
    data_df["model_eval"] = ["incorrect" if post in flagged else "correct"
                             for post in data_df[text_field].to_list()]
    incorrect = len(selections)
    correct = num_posts - incorrect
    # Avoid ZeroDivisionError when there is nothing to evaluate.
    accuracy = (correct / num_posts) * 100 if num_posts else 0.0
    data_df.to_csv("output.csv")
    return incorrect, correct, accuracy, data_df, gr.DownloadButton(label="Download CSV", value="output.csv", visible=True)
def init_queries(history):
    """Fill the query checkbox group, seeding it with default questions.

    Args:
        history: Queries already stored in the session state, or ``None`` /
            empty on first use.

    Returns:
        A ``CheckboxGroup`` offering the queries as choices, and the query list
        itself so it can be written back to the session state.
    """
    history = history or []
    if not history:
        history = [
            "What areas are being evacuated?",
            "What areas are predicted to be impacted?",
            "What areas are without power?",
            "What barriers are hindering response efforts?",
            "What events have been canceled?",
            "What preparations are being made?",
            "What regions have announced a state of emergency?",
            "What roads are blocked / closed?",
            "What services have been closed?",
            "What warnings are currently in effect?",
            "Where are emergency services deployed?",
            "Where are emergency services needed?",
            "Where are evacuations needed?",
            "Where are people needing rescued?",
            "Where are recovery efforts taking place?",
            "Where has building or infrastructure damage occurred?",
            # A missing comma here used to concatenate this query with the next.
            "Where has flooding occured?",
            "Where are volunteers being requested?",
            "Where has road damage occured?",
            "What area has the wildfire burned?",
            "Where have homes been damaged or destroyed?",
        ]
    return gr.CheckboxGroup(choices=history), history
def add_query(to_add, history):
    """Append a custom query to the session's query list if it is new.

    Surrounding whitespace is stripped and blank input is ignored. ``history``
    may still be ``None`` (the initial ``gr.State()`` value) if the list was
    never initialised.

    Returns:
        A ``CheckboxGroup`` with the updated choices, and the updated list.
    """
    if history is None:
        history = []
    query = (to_add or "").strip()
    if query and query not in history:
        history.append(query)
    return gr.CheckboxGroup(choices=history), history
def qa_summarise(selected_queries, qa_llm_model, text_field, response_lang, data_df):
    """Answer the selected queries over the posts classified as relevant.

    Args:
        selected_queries: Queries ticked by the user; at least one is required.
        qa_llm_model: QA model name forwarded to ``generate_answer``.
        text_field: Column holding the post text; whitespace is stripped, as in
            the classification step that produced ``data_df``.
        response_lang: Response language forwarded to ``generate_answer``.
        data_df: Classified dataframe with ``event_label`` and ``tweet_id`` columns.

    Returns:
        The generated summary, and a dataframe (``number``, ``text``, ``IDs``)
        listing the posts that were passed to the QA model.

    Raises:
        gr.Error: If no query is selected or no post is labelled as relevant.
    """
    if not selected_queries:
        raise gr.Error("Error: You have to select one or more queries to ask.")
    text_field = text_field.strip()
    # Posts predicted as "none" are excluded from QA. drop=True avoids a clash
    # if the data already contains a column named "index".
    qa_input_df = data_df[data_df["event_label"] != "none"].reset_index(drop=True)
    texts = qa_input_df[text_field].to_list()
    if not texts:
        raise gr.Error("Error: No relevant posts were found to answer the queries.")
    summary = generate_answer(qa_llm_model,
                              texts,
                              selected_queries[0],
                              selected_queries,
                              response_lang,
                              mode="multi_summarize")
    doc_df = pd.DataFrame({
        "number": list(range(1, len(texts) + 1)),
        "text": texts,
        "IDs": qa_input_df["tweet_id"].to_list(),
    })
    return summary, doc_df
def filter_df(df, selected_filter):
    """Return the rows of ``df`` matching a choice of the label-filter dropdown.

    ``selected_filter`` is ``"All"`` (every row), ``"<label>"`` (rows whose
    ``event_label`` equals it) or ``"Not-<label>"`` (rows whose label differs).
    An empty value (dropdown not yet populated) is treated as ``"All"``.
    """
    if not selected_filter or selected_filter == "All":
        return df.copy()
    negation_prefix = "Not-"
    if selected_filter.startswith(negation_prefix):
        # Slice instead of split("-") so hyphenated labels stay intact.
        excluded = selected_filter[len(negation_prefix):]
        return df[df["event_label"] != excluded].copy()
    return df[df["event_label"] == selected_filter].copy()


THRESHOLD_INFO = ("This value sets a threshold by which texts classified flood or fire are accepted, "
                  "higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)")

with gr.Blocks(fill_width=True) as demo:
    # Inject Twitter's widgets.js once when the page loads.
    demo.load(None, None, None, js=loadTwitterWidgets_js)

    event_models = ["jayebaku/XLMRoberta-twitter-crexdata-flood-wildfire-detector",
                    "jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier"]
    # Unfiltered classification result; source for the label filter and for QA.
    T_data_ss_state = gr.State(value=pd.DataFrame())

    with gr.Tab("Single Text Classification"):
        gr.Markdown(
            """
            # Single Text Classifier Demo
            In this section you test the relevance classifier with written texts.\n
            Usage:\n
            - Type a tweet-like text in the textbox.\n
            - Then press Enter.\n
            """)
        with gr.Group():
            with gr.Row():
                with gr.Column(scale=3):
                    model_sing_classify = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
                with gr.Column(scale=7):
                    with gr.Accordion("Prediction threshold", open=False):
                        threshold_sing_classify = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False,
                                                            info=THRESHOLD_INFO, interactive=True)
            text_to_classify = gr.Textbox(label="Text", info="Enter tweet-like text", submit_btn=True)
            text_to_classify_examples = gr.Examples([["The streets are flooded, I can't leave #BostonStorm"],
                                                     ["Controlado el incendio de Rodezno que ha obligado a desalojar a varias bodegas de la zona."],
                                                     ["Cambrils:estació Renfe inundada 19 persones dins d'un tren. FGC a Capellades, petit descarrilament 5 passatgers #Inuncat @emergenciescat"],
                                                     ["Anscheinend steht die komplette Neckarwiese unter Wasser! #Hochwasser"]], text_to_classify)
        with gr.Group():
            with gr.Row():
                with gr.Column():
                    classification = gr.Textbox(label="Classification")
                with gr.Column():
                    classification_score = gr.Number(label="Classification Score")

    with gr.Tab("Event Type Classification"):
        gr.Markdown(
            """
            # Relevance Classifier Demo
            This is a demo created to explore floods and wildfire classification in social media posts.\n
            Upload .tsv or .csv data file (must contain a text column with social media posts), next enter the name of the text column, choose classifier model, and click 'start prediction'.
            """)
        with gr.Group():
            with gr.Row(equal_height=True):
                with gr.Column():
                    T_file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])
                with gr.Column():
                    T_text_field = gr.Textbox(label="Text field name", value="tweet_text")
                    T_event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
                    with gr.Accordion("Prediction threshold", open=False):
                        T_threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False,
                                                info=THRESHOLD_INFO, interactive=True)
                    T_predict_button = gr.Button("Start Prediction")
        T_examples = gr.Examples([["./samples.tsv", "tweet_content", "jayebaku/XLMRoberta-twitter-crexdata-flood-wildfire-detector", 0.00]],
                                 inputs=[T_file_input, T_text_field, T_event_model, T_threshold])
        gr.Markdown("""Select an ID cell in dataframe to view Embedded tweet""")
        T_tweetID = gr.Textbox(visible=False)
        with gr.Group():
            with gr.Row():
                with gr.Column(scale=3):
                    T_data_filter = gr.Dropdown(visible=False)
                    T_tweet_embed = gr.HTML("""<div id="tweet-container"></div>""")
                with gr.Column(scale=7):
                    T_data = gr.DataFrame(wrap=True,
                                          show_fullscreen_button=True,
                                          show_copy_button=True,
                                          show_row_numbers=True,
                                          show_search="filter",
                                          max_height=1000,
                                          column_widths=["49%", "17%", "17%", "17%"])

    qa_tab = gr.Tab("Question Answering")
    with qa_tab:
        gr.Markdown(
            """
            # Question Answering Demo
            This section uses RAG to answer questions about the relevant social media posts identified by the relevance classifier\n
            Usage:\n
            - Select queries from predefined\n
            - Parameters for QA can be editted in sidebar\n
            Note: QA process is disabled untill after the relevance classification is done
            """)
        with gr.Group():
            with gr.Accordion("Parameters", open=False):
                with gr.Row():
                    with gr.Column():
                        qa_llm_model = gr.Dropdown(["mistral", "solar", "phi3mini"], label="QA model", value="phi3mini", interactive=True)
                        # NOTE(review): aggregator, batch_size and topk are not passed to
                        # qa_summarise, so changing them currently has no effect.
                        aggregator = gr.Dropdown(["linear", "outrank"], label="Aggregation method", value="linear", interactive=True)
                    with gr.Column():
                        batch_size = gr.Slider(50, 500, value=150, step=1, label="Batch size", info="Choose between 50 and 500", interactive=True)
                        topk = gr.Slider(1, 10, value=5, step=1, label="Number of results to retrieve", info="Choose between 1 and 10", interactive=True)
                        response_lang = gr.Dropdown(["english", "german", "catalan", "spanish"], label="Response language", value="english", interactive=True)
            selected_queries = gr.CheckboxGroup(label="Select at least one query using the checkboxes", interactive=True)
            queries_state = gr.State()
            qa_tab.select(init_queries, inputs=queries_state, outputs=[selected_queries, queries_state])
            query_inp = gr.Textbox(label="Add custom queries like the one above, one at a time")
            QA_addqry_button = gr.Button("Add to queries", interactive=False)
        QA_run_button = gr.Button("Start QA", interactive=False)
        hsummary = gr.Textbox(label="Summary")
        qa_tweetID = gr.Textbox(visible=False)
        with gr.Group():
            with gr.Row():
                with gr.Column(scale=7):
                    qa_df = gr.DataFrame(wrap=True,
                                         show_fullscreen_button=True,
                                         show_copy_button=True,
                                         show_search="filter",
                                         max_height=1000,
                                         column_widths=["10%", "70%", "20%"])
                with gr.Column(scale=3):
                    qa_tweet_embed = gr.HTML("""<div id="tweet-container2"></div>""")

    # with gr.Tab("Event Type Classification Eval"):
    #     gr.Markdown(
    #         """
    #         # T4.5 Relevance Classifier Demo
    #         This is a demo created to explore floods and wildfire classification in social media posts.\n
    #         Usage:\n
    #         - Upload .tsv or .csv data file (must contain a text column with social media posts).\n
    #         - Next, type the name of the text column.\n
    #         - Then, choose a BERT classifier model from the drop down.\n
    #         - Finally, click the 'start prediction' buttton.\n
    #         Evaluation:\n
    #         - To evaluate the model's accuracy select the INCORRECT classifications using the checkboxes in front of each post.\n
    #         - Then, click on the 'Calculate Accuracy' button.\n
    #         - Then, click on the 'Download data as CSV' to get the classifications and evaluation data as a .csv file.
    #         """)
    #     with gr.Row():
    #         with gr.Column(scale=4):
    #             file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])
    #         with gr.Column(scale=6):
    #             text_field = gr.Textbox(label="Text field name", value="tweet_text")
    #             event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
    #             ETCE_predict_button = gr.Button("Start Prediction")
    #             with gr.Accordion("Prediction threshold", open=False):
    #                 threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False,
    #                                       info=THRESHOLD_INFO, interactive=True)
    #     with gr.Row():
    #         with gr.Column():
    #             gr.Markdown("""### Flood-related""")
    #             flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
    #         with gr.Column():
    #             gr.Markdown("""### Fire-related""")
    #             fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
    #         with gr.Column():
    #             gr.Markdown("""### None""")
    #             none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
    #     with gr.Row():
    #         with gr.Column(scale=5):
    #             gr.Markdown(r"""
    #             Accuracy: is the model's ability to make correct predicitons.
    #             It is the fraction of correct prediction out of the total predictions.
    #             $$
    #             \text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} * 100
    #             $$
    #             Model Confidence: is the mean probabilty of each case
    #             belonging to their assigned classes. A value of 1 is best.
    #             """, latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }])
    #             gr.Markdown("\n\n\n")
    #             model_confidence = gr.Number(label="Model Confidence")
    #         with gr.Column(scale=5):
    #             correct = gr.Number(label="Number of correct classifications")
    #             incorrect = gr.Number(label="Number of incorrect classifications")
    #             accuracy = gr.Number(label="Model Accuracy (%)")
    #     ETCE_accuracy_button = gr.Button("Calculate Accuracy")
    #     download_csv = gr.DownloadButton(visible=False)
    #     num_posts = gr.Number(visible=False)
    #     data = gr.DataFrame(visible=False)
    #     data_eval = gr.DataFrame(visible=False)

    # Client-side JS that renders the tweet whose ID is `x` into the container
    # named by the <=CONTAINER-NAME=> placeholder (substituted per use below).
    createEmbedding_js = """ (x) =>
    {
        reloadTwitterWidgets();
        const tweetContainer = document.getElementById("<=CONTAINER-NAME=>");
        tweetContainer.innerHTML = "";
        twttr.widgets.createTweet(x,tweetContainer,{theme: 'dark', dnt: true, align: 'center'});
    }
    """

    # Test event listeners
    T_predict_button.click(
        load_and_classify_csv_dataframe,
        inputs=[T_file_input, T_text_field, T_event_model, T_threshold],
        outputs=[T_data, T_data_ss_state, T_data_filter, QA_addqry_button, QA_run_button])
    # Re-filter the displayed table from the unfiltered copy kept in state;
    # without this listener the "Filter data by label" dropdown did nothing.
    T_data_filter.change(filter_df, inputs=[T_data_ss_state, T_data_filter], outputs=T_data)
    T_data.select(T_on_select, None, T_tweetID)
    T_tweetID.change(fn=None, inputs=T_tweetID, outputs=None, js=createEmbedding_js.replace("<=CONTAINER-NAME=>", "tweet-container"))

    # Button clicks ETC Eval
    # ETCE_predict_button.click(
    #     load_and_classify_csv,
    #     inputs=[file_input, text_field, event_model, threshold],
    #     outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence, num_posts, data, QA_addqry_button, QA_run_button])
    # ETCE_accuracy_button.click(
    #     calculate_accuracy,
    #     inputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, num_posts, text_field, data],
    #     outputs=[incorrect, correct, accuracy, data_eval, download_csv])

    # Button clicks QA
    QA_addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])
    QA_run_button.click(qa_summarise,
                        inputs=[selected_queries, qa_llm_model, T_text_field, response_lang, T_data_ss_state],
                        outputs=[hsummary, qa_df])
    qa_df.select(T_on_select, None, qa_tweetID)
    qa_tweetID.change(fn=None, inputs=qa_tweetID, outputs=None, js=createEmbedding_js.replace("<=CONTAINER-NAME=>", "tweet-container2"))

    # Event listener for single text classification
    text_to_classify.submit(
        single_classification,
        inputs=[text_to_classify, model_sing_classify, threshold_sing_classify],
        outputs=[classification, classification_score])

if __name__ == "__main__":
    demo.launch()