Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import gradio as gr | |
| import pandas as pd | |
| from classifier import classify | |
# Hugging Face access token, read once at import time. Indexing os.environ
# raises KeyError when HF_TOKEN is unset, so a missing token stops the app at
# startup. load_and_analyze_csv passes this value on to classify().
HFTOKEN = os.environ["HF_TOKEN"]
def load_and_analyze_csv(file, text_field, event_model):
    """Classify every post in an uploaded table and group the posts by event.

    Args:
        file: The value delivered by the ``gr.File`` input. Either an object
            whose ``name`` attribute holds the path of the uploaded file, or
            the path itself. ``None`` when nothing has been uploaded.
        text_field: Name of the column that contains the text to classify.
        event_model: Model identifier forwarded to ``classify``.

    Returns:
        A tuple ``(fire_related, flood_related, not_related)`` of
        ``gr.CheckboxGroup`` updates whose choices are the posts classified
        as ``'fire'``, ``'flood'`` and anything else, respectively.

    Raises:
        gr.Error: If no file was uploaded or ``text_field`` is not a column
            of the uploaded table.
    """
    if file is None:
        # Without this guard the click handler fails on ``file.name`` with an
        # opaque AttributeError instead of a message the user can act on.
        raise gr.Error("Error: Please upload a file before starting the prediction.")

    # Accept both an upload object exposing ``.name`` and a plain path string.
    path = getattr(file, "name", file)

    # NOTE(review): read_table parses tab-separated text by default although
    # the UI labels the upload as "CSV" -- confirm the expected delimiter.
    df = pd.read_table(path)
    if text_field not in df.columns:
        raise gr.Error(f"Error: Text column '{text_field}' not in CSV file.")

    floods, fires, nones = [], [], []
    # Empty cells are read as NaN; skip them so they are neither sent to the
    # classifier nor offered as checkbox choices.
    for post in df[text_field].dropna().to_list():
        event = classify(post, event_model, HFTOKEN)["event"]
        if event == 'flood':
            floods.append(post)
        elif event == 'fire':
            fires.append(post)
        else:
            nones.append(post)

    fire_related = gr.CheckboxGroup(choices=fires)
    flood_related = gr.CheckboxGroup(choices=floods)
    not_related = gr.CheckboxGroup(choices=nones)
    return fire_related, flood_related, not_related
def analyze_selected_texts(selections):
    """Build a table pairing each selected text with its word count.

    Args:
        selections: The texts to analyse; each item must provide ``split()``.

    Returns:
        A ``pandas.DataFrame`` with the columns ``"Selected Text"`` (the
        items of ``selections``) and ``"Analysis"`` (a ``"Word Count: N"``
        string per item, where N is the number of whitespace-separated
        tokens).
    """
    word_counts = []
    for entry in selections:
        n_words = len(entry.split())
        word_counts.append(f"Word Count: {n_words}")

    table = {"Selected Text": selections, "Analysis": word_counts}
    return pd.DataFrame(table)
# Two-tab Gradio interface: the first tab uploads a table and classifies its
# posts into fire / flood / none; the second tab analyses the posts ticked in
# the fire checkbox group.
with gr.Blocks() as demo:
    event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier"]

    with gr.Tab("Event Type Classification"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                file_input = gr.File(label="Upload CSV File")
            with gr.Column(scale=6):
                text_field = gr.Textbox(label="Text field name", value="tweet_text")
                event_model = gr.Dropdown(event_models, label="Select classification model")
                predict_button = gr.Button("Start Prediction")

        # TODO: confirm whether this row also needs equal_height=True.
        with gr.Row():
            # Each checkbox group sits under the heading that names its class.
            # Previously the flood heading was above the fire group and vice
            # versa, so results were displayed under the wrong label.
            with gr.Column():
                gr.Markdown("""### Flood-related""")
                flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications")
            with gr.Column():
                gr.Markdown("""### Fire-related""")
                fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications")
            with gr.Column():
                gr.Markdown("""### None""")
                none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications")

        # The output order must match the (fire, flood, none) tuple returned
        # by load_and_analyze_csv.
        predict_button.click(
            load_and_analyze_csv,
            inputs=[file_input, text_field, event_model],
            outputs=[fire_checkbox_output, flood_checkbox_output, none_checkbox_output],
        )

    with gr.Tab("Question Answering"):
        # TODO: disable this button until the classification step has completed.
        analysis_button = gr.Button("Analyze Selected Texts")
        analysis_output = gr.DataFrame(headers=["Selected Text", "Analysis"])
        analysis_button.click(analyze_selected_texts, inputs=fire_checkbox_output, outputs=analysis_output)

demo.launch()