Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -35,14 +35,6 @@ def load_and_classify_csv(file, text_field, event_model):
|
|
| 35 |
not_related = gr.CheckboxGroup(choices=df[df["model_label"]=="none"][text_field].to_list())
|
| 36 |
|
| 37 |
return flood_related, fire_related, not_related, model_confidence, len(df[text_field].to_list()), df
|
| 38 |
-
|
| 39 |
-
def qa_process(selections):
    """Build a word-count table for the selected texts.

    Args:
        selections: Iterable of strings to analyse. ``None`` is treated as
            "nothing selected" and yields an empty table.

    Returns:
        A DataFrame with one row per text and two columns:
        "Selected Text" (the text itself) and "Analysis"
        (a "Word Count: N" string, where N counts whitespace-separated tokens).
    """
    # Normalise to a concrete list so None and one-shot iterables are safe
    # to traverse twice (once for the analysis, once for the frame).
    selected_texts = [] if selections is None else list(selections)

    analysis_results = [f"Word Count: {len(text.split())}" for text in selected_texts]

    return pd.DataFrame({"Selected Text": selected_texts, "Analysis": analysis_results})
|
| 46 |
|
| 47 |
def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
|
| 48 |
posts = data_df[text_field].to_list()
|
|
@@ -94,6 +86,42 @@ def add_query(to_add, history):
|
|
| 94 |
if to_add not in history:
|
| 95 |
history.append(to_add)
|
| 96 |
return gr.CheckboxGroup(choices=history), history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
with gr.Blocks() as demo:
|
| 99 |
event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier"]
|
|
@@ -209,7 +237,9 @@ with gr.Blocks() as demo:
|
|
| 209 |
|
| 210 |
|
| 211 |
addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])
|
| 212 |
-
qa_button.click(qa_process,
|
| 213 |
-
|
|
|
|
|
|
|
| 214 |
|
| 215 |
demo.launch()
|
|
|
|
| 35 |
not_related = gr.CheckboxGroup(choices=df[df["model_label"]=="none"][text_field].to_list())
|
| 36 |
|
| 37 |
return flood_related, fire_related, not_related, model_confidence, len(df[text_field].to_list()), df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
|
| 40 |
posts = data_df[text_field].to_list()
|
|
|
|
| 86 |
if to_add not in history:
|
| 87 |
history.append(to_add)
|
| 88 |
return gr.CheckboxGroup(choices=history), history
|
| 89 |
+
|
| 90 |
+
def qa_process(selected_queries, qa_llm_model, aggregator,
               batch_size, topk, text_field, data_df):
    """Run the GENRA pipeline over the posts in batches and summarise the answers.

    Args:
        selected_queries: List of query strings to answer.
        qa_llm_model: Passed through to ``GenraPipeline`` as its first argument.
        aggregator: Passed through to ``GenraPipeline`` as its third argument.
        batch_size: Number of posts per batch; coerced to ``int`` because it is
            used as a ``range`` step and slice bound.
        topk: Passed through to ``genra.retrieval``.
        text_field: Name of the column in ``data_df`` that holds the post text.
        data_df: DataFrame containing the posts.

    Returns:
        A ``(result_df, summary)`` tuple. ``result_df`` has one row per query
        with its word count; ``summary`` is whatever
        ``genra.summarize_history`` returns.

    NOTE(review): the ``qa_button.click`` wiring visible in this file lists a
    single output component while this function returns two values — confirm
    the ``outputs=[...]`` list matches.
    """
    emb_model = 'multi-qa-mpnet-base-dot-v1'
    contexts = []

    queries_df = pd.DataFrame({
        'id': list(range(len(selected_queries))),
        'query': selected_queries,
    })

    # Non-inplace chain: mutating a column slice of data_df in place triggers
    # pandas' SettingWithCopyWarning; this builds an independent frame instead.
    tweets_df = (
        data_df[[text_field]]
        .reset_index()
        .rename(columns={"index": "order"})
    )

    # range() and slicing need an int step; a UI widget may supply a float.
    batch_size = int(batch_size)

    gr.Info("Loading GENRA pipeline....")
    genra = GenraPipeline(qa_llm_model, emb_model, aggregator, contexts)
    gr.Info("Waiting for data...")
    batches = [
        tweets_df[start:start + batch_size]
        for start in range(0, len(tweets_df), batch_size)
    ]

    summarize_batch = True
    for batch_number, tweets in enumerate(batches):
        gr.Info(f"Populating index for batch {batch_number}")
        genra.qa_indexer.index_dataframe(tweets)
        gr.Info(f"Performing retrieval for batch {batch_number}")
        genra.retrieval(batch_number, queries_df, topk, summarize_batch)

    gr.Info("Processed all batches!")
    # result ------ genra.answers_store

    summary = genra.summarize_history(queries_df)

    analysis_results = [f"Word Count: {len(text.split())}" for text in selected_queries]

    result_df = pd.DataFrame({"Selected Text": selected_queries, "Analysis": analysis_results})
    return result_df, summary
|
| 124 |
+
|
| 125 |
|
| 126 |
with gr.Blocks() as demo:
|
| 127 |
event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier"]
|
|
|
|
| 237 |
|
| 238 |
|
| 239 |
addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])
|
| 240 |
+
qa_button.click(qa_process,
|
| 241 |
+
inputs=[selected_queries, qa_llm_model, aggregator, batch_size, topk, text_field, data],
|
| 242 |
+
outputs=[analysis_output, ])
|
| 243 |
+
|
| 244 |
|
| 245 |
demo.launch()
|