"""FastAPI microservice exposing three NLP endpoints over Hugging Face pipelines:

- POST /predict  — zero-shot classification of feedback text into topic labels
- POST /sa       — binary sentiment analysis
- POST /sum      — abstractive summarization
- GET  /         — health/config probe
"""

from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Dict
from transformers import pipeline
import os

app = FastAPI()

# === MODELS (unchanged per your request) ===
# ZSC_MODEL = "MoritzLaurer/deberta-v3-large-zeroshot-v2.0-c"
ZSC_MODEL = "facebook/bart-large-mnli"
SA_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"
SUM_MODEL = "t5-small"

# Fallback label set used when a /predict request supplies no labels.
DEFAULT_LABELS = [
    "Usability",
    "Performance",
    "Visual Design",
    "Feedback",
    "Navigation",
    "Responsiveness",
]
# Hypothesis template for zero-shot NLI; "{}" is replaced by each candidate label.
TEMPLATE = "This feedback is primarily about {}."

# Pipelines load model weights once at import time; the startup hook below
# warms the classifier so the first real request doesn't pay the full cost.
zsc = pipeline("zero-shot-classification", model=ZSC_MODEL)
sa = pipeline("sentiment-analysis", model=SA_MODEL)
summ = pipeline("summarization", model=SUM_MODEL)


class ZSCReq(BaseModel):
    """Zero-shot classification request body."""
    text: str
    labels: List[str] = []          # empty list -> DEFAULT_LABELS (pydantic copies defaults, so this is safe)
    multi_label: bool = False       # True: score each label independently
    template: str = TEMPLATE


class SAReq(BaseModel):
    """Sentiment-analysis request body."""
    text: str


class SumReq(BaseModel):
    """Summarization request body; length bounds are in generated tokens."""
    text: str
    max_length: int = 60
    min_length: int = 20
    do_sample: bool = False


# NOTE(review): on_event is deprecated in recent FastAPI in favor of lifespan
# handlers; kept as-is to avoid restructuring the app object.
@app.on_event("startup")
def warmup():
    """Run one throwaway classification so the first real request is fast."""
    try:
        _ = zsc("warmup", candidate_labels=DEFAULT_LABELS, multi_label=False, truncation=True)
    except Exception as e:
        # Best effort: a failed warmup only costs first-request latency.
        print(f"[warmup] skipped: {e}")


@app.get("/")
def health():
    """Liveness probe that also reports the active classifier config."""
    return {"status": "ok", "model": ZSC_MODEL, "labels": DEFAULT_LABELS}


@app.post("/predict")
def predict(req: ZSCReq):
    """Classify feedback text; returns labels and scores sorted best-first."""
    labels = req.labels or DEFAULT_LABELS
    out = zsc(
        req.text,
        candidate_labels=labels,
        # BUG FIX: was hardcoded False, silently ignoring the request's
        # multi_label field. Default is still False, so existing callers
        # are unaffected.
        multi_label=req.multi_label,
        hypothesis_template=(req.template or TEMPLATE),
        truncation=True,
    )
    # The pipeline already sorts descending, but re-sort defensively and
    # coerce scores to plain floats for JSON serialization.
    pairs = sorted(zip(out["labels"], out["scores"]), key=lambda p: float(p[1]), reverse=True)
    if not pairs:
        return {"labels": [], "scores": []}
    lbls, scs = zip(*pairs)
    return {"labels": list(lbls), "scores": [float(s) for s in scs]}


@app.post("/sa")
def sentiment(req: SAReq):
    """Return the top sentiment label (POSITIVE/NEGATIVE) and its score."""
    r = sa(req.text)[0]
    return {"label": r["label"], "score": float(r["score"])}


@app.post("/sum")
def summarize(req: SumReq):
    """Summarize text within the requested token-length bounds."""
    r = summ(
        req.text,
        max_length=req.max_length,
        min_length=req.min_length,
        do_sample=req.do_sample,
        truncation=True,
    )[0]
    return {"summary_text": r["summary_text"]}