File size: 2,446 Bytes
7be2c11
4c6440f
 
 
48854b7
7be2c11
 
1e431ad
48854b7
7be2c11
 
 
 
48854b7
04cc170
 
 
7be2c11
04cc170
 
 
 
7be2c11
 
 
 
 
 
 
 
 
 
04cc170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48854b7
1e431ad
 
7be2c11
8b6417b
7be2c11
 
 
04cc170
 
7be2c11
1e431ad
 
48854b7
1e431ad
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Prerequisites
import functools
import json
import os

# Kept ahead of the third-party imports, as in the original layout, so the
# thread limit is already in the environment when those libraries load.
os.environ["OMP_NUM_THREADS"] = "1"  # Set 1, 2, or 4 depending on CPU usage

import gradio as gr
import pandas as pd
from transformers import pipeline

# Candidate labels: a mapping whose keys are the "superclasses" iterated by
# classify() and whose values are the label lists handed to the classifier.
# The encoding is pinned to UTF-8 so the files decode identically regardless
# of the platform's locale default.
with open("packing_label_structure.json", "r", encoding="utf-8") as labels_file:
    candidate_labels = json.load(labels_file)
keys_list = list(candidate_labels)

# Packing item data: looked up by predicted class name in classify().
with open(
    "packing_templates_self_supported_offgrid_expanded.json",
    "r",
    encoding="utf-8",
) as items_file:
    packing_items = json.load(items_file)

# function and gradio app
@functools.lru_cache(maxsize=1)
def _load_classifier(model_name):
    """Build the zero-shot classification pipeline for *model_name*.

    The most recently built pipeline is cached so that consecutive requests
    using the same model do not reload it. Only one entry is kept, so
    switching models does not accumulate several pipelines in memory.
    """
    return pipeline("zero-shot-classification", model=model_name)


def _unique_sorted_items(class_names):
    """Return the sorted, de-duplicated packing items for *class_names*."""
    items = set()
    for name in class_names:
        # Classes without an entry in ``packing_items`` contribute nothing.
        items.update(packing_items.get(name, []))
    return sorted(items)


def classify(model_name, trip_descr, cut_off=0.5):
    """Classify a trip description and derive the matching packing list.

    For every key in ``keys_list`` the description is scored against that
    key's candidate labels. The ``'activities'`` key is treated as
    multi-label: every label scoring above ``cut_off`` is kept. For all other
    keys only the top-ranked label is kept.

    Args:
        model_name: Model identifier passed to the zero-shot pipeline.
        trip_descr: Free-text description of the trip.
        cut_off: Score threshold for keeping an ``'activities'`` label.
            ``None`` falls back to the default of 0.5.

    Returns:
        A tuple ``(df, items_text)``. ``df`` is a DataFrame with columns
        ``'superclass'`` and ``'pred_class'`` (a list of labels for
        ``'activities'``, a single label otherwise). ``items_text`` is the
        alphabetically sorted, de-duplicated packing items for all predicted
        classes, joined with newlines.
    """
    if cut_off is None:
        # NOTE(review): presumably the UI number field yields None when it is
        # cleared; fall back to the default instead of failing on the score
        # comparison below -- confirm against the Gradio version in use.
        cut_off = 0.5

    classifier = _load_classifier(model_name)

    ## Collect class predictions per superclass
    rows = []
    all_classes = []
    for key in keys_list:
        labels = candidate_labels[key]
        if key == 'activities':
            result = classifier(trip_descr, labels, multi_label=True)
            classes = [
                label
                for label, score in zip(result['labels'], result['scores'])
                if score > cut_off
            ]
            all_classes.extend(classes)
        else:
            result = classifier(trip_descr, labels)
            classes = result["labels"][0]
            all_classes.append(classes)
        rows.append([key, classes])
    df = pd.DataFrame(rows, columns=['superclass', 'pred_class'])

    ## Look up the items to pack for the predicted classes; duplicates are
    ## removed and the result is sorted alphabetically
    return df, "\n".join(_unique_sorted_items(all_classes))

# Input widgets, listed in the positional order of classify()'s parameters.
interface_inputs = [
    gr.Textbox(
        label="Model name",
        value="MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli",
    ),
    gr.Textbox(label="Trip description"),
    gr.Number(label="Activity cut-off", value=0.5),
]

# Output widgets: one for each of the two values classify() returns.
interface_outputs = [
    gr.Dataframe(label="DataFrame"),
    gr.Textbox(label="List of words"),
]

demo = gr.Interface(
    fn=classify,
    inputs=interface_inputs,
    outputs=interface_outputs,
    title="Trip classification",
    description="Enter a text describing your trip",
)

# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()