packing_list / app.py
AnjaJuana
Auto-update from GitHub
aa8447d
raw
history blame
1.67 kB
# Prerequisites
from transformers import pipeline
import json
import pandas as pd
import gradio as gr
# Candidate labels: a JSON object keyed by superclass name (e.g. "activities"),
# each value being the candidate labels handed to the zero-shot classifier.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying on
# the platform's locale-dependent default encoding.
with open("packing_label_structure.json", "r", encoding="utf-8") as file:
    candidate_labels = json.load(file)
# Superclass names, in the order they appear in the JSON file.
keys_list = list(candidate_labels)

# Load test data (in list of dictionaries), also decoded as UTF-8.
with open("test_data.json", "r", encoding="utf-8") as file:
    packing_data = json.load(file)
# Hugging Face zero-shot classification pipeline backed by facebook/bart-large-mnli.
# It is built once, at import time, and shared by every request.
model_name = "facebook/bart-large-mnli"
classifier = pipeline(task="zero-shot-classification", model=model_name)

# Module-level default threshold for "activities" scores. classify() receives
# its threshold as a parameter (which shadows this name), so the UI value wins.
cut_off = 0.5
def classify(trip_descr, cut_off=0.5, *, labels=None, zero_shot=None):
    """Predict packing-list classes for a free-text trip description.

    The "activities" superclass is scored as multi-label, and every label
    whose score exceeds ``cut_off`` is kept. Every other superclass yields
    only the first entry of the classifier's ``labels`` result. The
    transformers zero-shot pipeline orders that result by descending score,
    so this is the top-scoring label.

    Args:
        trip_descr: Free-text description of the trip.
        cut_off: Score threshold for "activities" labels. ``None`` (what a
            cleared Gradio Number field sends) falls back to 0.5 instead of
            raising ``TypeError`` on the comparison.
        labels: Mapping of superclass -> candidate labels. Defaults to the
            module-level ``candidate_labels``.
        zero_shot: Callable with the zero-shot pipeline's interface
            (``zero_shot(text, candidates, multi_label=...)`` returning a dict
            with "labels" and "scores"). Defaults to the module-level
            ``classifier``.

    Returns:
        A DataFrame with columns ``superclass`` and ``pred_class``, one row per
        superclass in mapping order. ``pred_class`` holds a list for
        "activities" and a single label for every other superclass.
    """
    if labels is None:
        labels = candidate_labels
    if zero_shot is None:
        zero_shot = classifier
    if cut_off is None:
        cut_off = 0.5

    rows = []
    for superclass, options in labels.items():
        if superclass == "activities":
            # Several activities can apply at once, so score labels
            # independently and keep all of them above the threshold.
            result = zero_shot(trip_descr, options, multi_label=True)
            predicted = [
                label
                for label, score in zip(result["labels"], result["scores"])
                if score > cut_off
            ]
        else:
            result = zero_shot(trip_descr, options)
            predicted = result["labels"][0]
        rows.append([superclass, predicted])

    # Build the frame once rather than enlarging an empty DataFrame row by row.
    return pd.DataFrame(rows, columns=["superclass", "pred_class"])
# Inputs are passed to classify() positionally: the trip text, then the
# activity threshold (pre-filled with 0.5).
_trip_inputs = [
    gr.Textbox(label="Trip description"),
    gr.Number(label="Activity cut-off", value=0.5),
]

# Web UI wrapping classify(); its DataFrame result is rendered as a table.
demo = gr.Interface(
    fn=classify,
    inputs=_trip_inputs,
    outputs="dataframe",
    title="Trip classification",
    description="Enter a text describing your trip",
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()