ux-zeroshot / app.py
adithimshrouthy's picture
Update app.py
401d729 verified
import os
from typing import Dict, List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
# FastAPI application object; the route handlers in this module register on it.
app = FastAPI()

# === MODELS ===
# Alternative zero-shot model, currently disabled:
# ZSC_MODEL="MoritzLaurer/deberta-v3-large-zeroshot-v2.0-c"
ZSC_MODEL = "facebook/bart-large-mnli"  # zero-shot classification
SA_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"  # sentiment analysis
SUM_MODEL = "t5-small"  # summarization

# Candidate labels used when a /predict request supplies none; also reported
# by the health endpoint and used for the startup warm-up call.
DEFAULT_LABELS = ["Usability","Performance","Visual Design","Feedback","Navigation","Responsiveness"]

# Default hypothesis template; "{}" is the slot for a candidate label.
TEMPLATE = "This feedback is primarily about {}."

# The pipelines are constructed at import time, so they are built once per
# process and shared by every request handler below.
zsc = pipeline("zero-shot-classification", model=ZSC_MODEL)
sa = pipeline("sentiment-analysis", model=SA_MODEL)
summ = pipeline("summarization", model=SUM_MODEL)
class ZSCReq(BaseModel):
    """Request body for the zero-shot classification endpoint (``/predict``)."""

    # Text to classify.
    text: str
    # Candidate labels; an empty list makes /predict use DEFAULT_LABELS.
    labels: List[str] = []
    # Requested multi-label mode; defaults to single-label scoring.
    multi_label: bool = False
    # Hypothesis template; defaults to the module-level TEMPLATE.
    template: str = TEMPLATE
class SAReq(BaseModel):
    """Request body for the sentiment-analysis endpoint (``/sa``)."""

    # Text whose sentiment is to be scored.
    text: str
class SumReq(BaseModel):
    """Request body for the summarization endpoint (``/sum``).

    The three optional fields are forwarded unchanged to the summarization
    pipeline by the ``/sum`` handler.
    """

    # Text to summarize.
    text: str
    # Upper bound on summary length (presumably in tokens; confirm against
    # the transformers generation documentation).
    max_length: int = 60
    # Lower bound on summary length (same unit as max_length).
    min_length: int = 20
    # Whether the pipeline should sample instead of decoding deterministically.
    do_sample: bool = False
@app.on_event("startup")
def warmup():
    """Run one throwaway zero-shot classification when the app starts.

    This is best-effort: if the call fails for any reason the error is
    printed and startup carries on.
    """
    try:
        zsc(
            "warmup",
            candidate_labels=DEFAULT_LABELS,
            multi_label=False,
            truncation=True,
        )
    except Exception as exc:  # deliberate catch-all: warm-up must never block startup
        print(f"[warmup] skipped: {exc}")
@app.get("/")
def health():
    """Report service status with the zero-shot model name and default labels."""
    return dict(status="ok", model=ZSC_MODEL, labels=DEFAULT_LABELS)
@app.post("/predict")
def predict(req: ZSCReq):
    """Classify ``req.text`` against candidate labels with the zero-shot model.

    Args:
        req: Request body. An empty ``labels`` list falls back to
            ``DEFAULT_LABELS``; an empty ``template`` falls back to
            ``TEMPLATE``. ``multi_label`` is forwarded to the pipeline.

    Returns:
        A dict with ``labels`` and ``scores`` lists of equal length, ordered
        from highest to lowest score.

    Raises:
        HTTPException: 422 when ``req.template`` cannot be formatted with a
            single label (no ``{}`` placeholder, or unsupported fields).
    """
    labels = req.labels or DEFAULT_LABELS
    template = req.template or TEMPLATE

    # Validate the client-supplied template up front so a bad one is reported
    # as a client error instead of surfacing as an unhandled exception from
    # the pipeline call.
    try:
        rendered = template.format(labels[0])
    except (AttributeError, IndexError, KeyError, TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=422,
            detail="template must be formattable with a single '{}' label placeholder",
        ) from exc
    if rendered == template:
        raise HTTPException(
            status_code=422,
            detail="template must contain a '{}' placeholder for the label",
        )

    out = zsc(
        req.text,
        candidate_labels=labels,
        # Honor the request flag; it defaults to False, so clients that do not
        # send it get the same single-label scoring as before.
        multi_label=req.multi_label,
        hypothesis_template=template,
        truncation=True,
    )

    # Sort defensively so the response order does not depend on the pipeline.
    pairs = sorted(
        zip(out["labels"], out["scores"]),
        key=lambda pair: float(pair[1]),
        reverse=True,
    )
    if not pairs:
        return {"labels": [], "scores": []}
    lbls, scs = zip(*pairs)
    # Cast to plain floats so the scores serialize cleanly to JSON.
    return {"labels": list(lbls), "scores": [float(s) for s in scs]}
@app.post("/sa")
def sentiment(req: SAReq):
    """Score the sentiment of ``req.text``.

    Args:
        req: Request body carrying the text to analyze.

    Returns:
        A dict with the predicted ``label`` and its ``score`` as a float,
        taken from the first result the pipeline returns.
    """
    # Truncate over-long inputs, as /predict and /sum already do, so the
    # tokenizer never hands the model more tokens than it accepts.
    result = sa(req.text, truncation=True)[0]
    return {"label": result["label"], "score": float(result["score"])}
@app.post("/sum")
def summarize(req: SumReq):
    """Summarize ``req.text`` with the summarization pipeline.

    Args:
        req: Request body; its length bounds and sampling flag are passed
            straight through to the pipeline.

    Returns:
        A dict with the ``summary_text`` of the first pipeline result.
    """
    generation_args = {
        "max_length": req.max_length,
        "min_length": req.min_length,
        "do_sample": req.do_sample,
        "truncation": True,
    }
    results = summ(req.text, **generation_args)
    first = results[0]
    return {"summary_text": first["summary_text"]}