File size: 682 Bytes
48854b7
1e431ad
48854b7
1e431ad
48854b7
 
1e431ad
 
 
48854b7
1e431ad
 
 
48854b7
1e431ad
 
 
 
 
 
 
48854b7
1e431ad
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from transformers import pipeline
import gradio as gr

# Zero-shot classification is implemented as natural-language inference (NLI):
# each candidate label is turned into a hypothesis and scored for entailment
# against the input text. That requires a checkpoint fine-tuned on an NLI task
# whose config exposes an "entailment" label. The plain `facebook/bart-base`
# checkpoint has no such head (it would be randomly initialized), so its scores
# are not meaningful. `facebook/bart-large-mnli` is the standard NLI checkpoint
# for this pipeline; note that it is a larger download than bart-base.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# Candidate labels come from labels.txt, one label per line. Surrounding
# whitespace is trimmed from each line, and lines that end up empty are skipped.
with open("labels.txt", "r", encoding="utf-8") as label_file:
    stripped_lines = (raw.strip() for raw in label_file)
    class_labels = [label for label in stripped_lines if label]

def classify(text):
    """Score ``text`` against every label loaded from ``labels.txt``.

    Args:
        text: Free-form text entered in the Gradio textbox.

    Returns:
        The zero-shot pipeline's result for ``text``, passed unchanged to
        the JSON output component.

    Raises:
        gr.Error: If ``text`` is empty or whitespace-only, or if no labels
            were loaded. Gradio shows the message to the user instead of
            returning scores for an empty input or an internal traceback.
    """
    # Scoring a blank input would still produce label probabilities, but
    # they would carry no information about the user's trip.
    if not text or not text.strip():
        raise gr.Error("Please enter some text describing your trip.")
    # The zero-shot pipeline rejects an empty candidate-label list with a
    # generic ValueError; report the actual cause instead.
    if not class_labels:
        raise gr.Error("No labels were found in labels.txt.")
    return classifier(text, candidate_labels=class_labels)

# Web UI: a multi-line textbox feeds `classify`, and the raw pipeline result
# is rendered as JSON.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(
        lines=3,
        label="Trip description",
        placeholder="e.g. A week of hiking and camping in the Alps",
    ),
    outputs=gr.JSON(label="Label scores"),
    title="Zero-Shot Classification",
    description="Describe your trip to see how strongly it matches each label.",
)

def main():
    """Launch the Gradio app defined by ``demo``."""
    demo.launch()


if __name__ == "__main__":
    main()