"""FastAPI service that stages a free-text report with remote TNM classifiers.

Each of the T, N and M components is predicted by a separate Hugging Face
text-classification model called through the hosted Inference API, so the
Space itself does no model compute.
"""

import logging
import os
from typing import Any, Dict

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# === Config ===
HF_TOKEN = os.getenv("HF_TOKEN")  # add this as a Secret in your Space
MODEL_T = os.getenv("MODEL_T", "jkefeli/CancerStage_Classifier_T")
MODEL_N = os.getenv("MODEL_N", "jkefeli/CancerStage_Classifier_N")
MODEL_M = os.getenv("MODEL_M", "jkefeli/CancerStage_Classifier_M")

# Model id backing each TNM component, in output order.
MODELS = {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}

# Inputs longer than this many characters are hard-truncated before inference.
MAX_INPUT_CHARS = int(os.getenv("MAX_INPUT_CHARS", "20000"))

# Comma-separated allowed browser origins; "*" (the default) allows any origin.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Initialize FastAPI
app = FastAPI(title="TNM Endpoint", version="1.0.0")

# CORS: any origin by default; set CORS_ALLOW_ORIGINS to tighten for production.
# Credentials are only enabled for an explicit origin list. With a "*" origin,
# Starlette would otherwise echo back whichever origin sent a credentialed
# request. Nothing in this service reads cookies or auth headers.
_allow_any_origin = "*" in CORS_ALLOW_ORIGINS
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=not _allow_any_origin,
    allow_methods=["*"],
    allow_headers=["*"],
)

# One remote Inference API client per TNM component (no heavy compute on the Space).
clients = {
    key: InferenceClient(model=model_id, token=HF_TOKEN)
    for key, model_id in MODELS.items()
}


class ReportInput(BaseModel):
    """Request body for /predict_tnm: the report text to classify."""

    text: str


@app.get("/healthz")
def healthz():
    """Liveness probe; also reports which model backs each TNM component."""
    return {"status": "ok", "models": dict(MODELS)}


def _field(item: Any, name: str, default: Any) -> Any:
    """Read *name* from one classification result element.

    Depending on the huggingface_hub version, result elements are plain dicts
    or dataclass-style objects, so try attribute access first and fall back to
    mapping access. Missing or ``None`` values yield *default*.
    """
    value = getattr(item, name, None)
    if value is None and isinstance(item, dict):
        value = item.get(name)
    return default if value is None else value


def _classify(client: InferenceClient, text: str) -> Dict[str, Any]:
    """Run one remote text-classification model on *text*; return its top label.

    Returns:
        ``{"label": str, "score": float, "raw": list}``, where ``raw`` holds every
        ``{"label", "score"}`` pair the model returned.

    Raises:
        HTTPException: 502 if the remote call fails, returns nothing, or the
            response cannot be parsed.
    """
    try:
        # InferenceClient.text_classification takes no ``wait_for_model``
        # option (that belonged to the legacy InferenceApi). Passing it raised
        # TypeError, which made every request fail with 502.
        outputs = client.text_classification(text)
        if not outputs:
            raise ValueError("Empty response from model")
        raw = [
            {
                "label": _field(item, "label", "UNKNOWN"),
                "score": float(_field(item, "score", 0.0)),
            }
            for item in outputs
        ]
    except Exception as e:
        logger.exception("Inference call failed")
        raise HTTPException(status_code=502, detail=f"Inference error: {e}") from e

    # max() keeps the first element among equal scores.
    best = max(raw, key=lambda item: item["score"])
    return {"label": best["label"], "score": best["score"], "raw": raw}


@app.post("/predict_tnm")
def predict_tnm(input: ReportInput):  # noqa: A002 - name kept for signature compatibility
    """Predict the T, N and M labels for a free-text report.

    Returns:
        ``input_chars`` (length after truncation), per-component predictions in
        ``tnm``, the space-joined labels in ``tnm_string``, and ``meta`` with the
        model ids and whether the input was truncated.

    Raises:
        HTTPException: 400 if the text is empty or whitespace-only; 502 if any
            model call fails.
    """
    text = (input.text or "").strip()
    if not text:
        raise HTTPException(status_code=400, detail="Empty 'text'")

    # Hard-truncate very long inputs to avoid API payload limits. This is a
    # character cap only; it does not account for any model token limit.
    truncated = len(text) > MAX_INPUT_CHARS
    if truncated:
        text = text[:MAX_INPUT_CHARS]

    preds = {key: _classify(client, text) for key, client in clients.items()}
    tnm_string = " ".join(preds[key]["label"] for key in ("T", "N", "M"))
    return {
        "input_chars": len(text),
        "tnm": preds,
        "tnm_string": tnm_string,
        "meta": {"models": dict(MODELS), "truncated": truncated},
    }