"""FastAPI service that predicts TNM cancer-stage classes from a report text.

Three BigBird sequence classifiers (one each for the T, N and M axes) and one
shared tokenizer are loaded once at import time and reused by every request.
"""

import logging
from typing import Any, Dict

import torch
from fastapi import Body, FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from scipy.special import softmax
from transformers import AutoTokenizer, BigBirdForSequenceClassification

logger = logging.getLogger(__name__)

# Initialize FastAPI
app = FastAPI(title="TNM Endpoint", version="1.0")

# Models (TNM) from Hugging Face
MODEL_T = "jkefeli/CancerStage_Classifier_T"
MODEL_N = "jkefeli/CancerStage_Classifier_N"
MODEL_M = "jkefeli/CancerStage_Classifier_M"

# Tokenizer checkpoint shared by all three classifiers.
TOKENIZER_NAME = "yikuan8/Clinical-BigBird"

# Default token limit; longer reports are truncated to this many tokens.
MAX_LENGTH = 2048

# Load tokenizer once
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# Load models once (CPU mode: nothing below moves them to another device)
model_T = BigBirdForSequenceClassification.from_pretrained(MODEL_T)
model_N = BigBirdForSequenceClassification.from_pretrained(MODEL_N)
model_M = BigBirdForSequenceClassification.from_pretrained(MODEL_M)

# Axis label -> Hub id / loaded model. Insertion order (T, N, M) fixes the
# key order of both the health-check and prediction responses.
_MODEL_IDS = {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}
_MODELS = {"T": model_T, "N": model_N, "M": model_M}


class Report(BaseModel):
    """Request body for ``POST /predict_tnm``."""

    # Free-text report to classify.
    text: str


def predict_stage(
    text: str,
    model: BigBirdForSequenceClassification,
    max_length: int = MAX_LENGTH,
) -> Dict[str, Any]:
    """Classify *text* with *model* and return the top class and probabilities.

    Args:
        text: Report text to classify.
        model: One of the loaded sequence classifiers.
        max_length: Token limit passed to the tokenizer; longer input is
            truncated. Defaults to ``MAX_LENGTH``.

    Returns:
        ``{"class": int, "probs": [[p0, p1, ...]]}``. ``probs`` is nested
        because the tokenizer returns a batch containing one sequence.
    """
    # A single sequence needs no padding; truncation bounds its length.
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=max_length
    )
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = softmax(logits.numpy(), axis=1)
    return {"class": int(probs.argmax(axis=1)[0]), "probs": probs.tolist()}


@app.get("/")
def health_check():
    """Report liveness and the Hub ids of the loaded T/N/M models."""
    return {"status": "running", "models": dict(_MODEL_IDS)}


@app.post("/predict_tnm")
def predict_tnm(report: Report = Body(...)):
    """Run the T, N and M classifiers on ``report.text``.

    On success returns ``{"input": text, "TNM_prediction": {"T": ..., "N": ...,
    "M": ...}}``, where each value is a ``predict_stage`` result. On failure
    returns ``{"error": message}`` with HTTP status 500.
    """
    text = report.text
    try:
        prediction = {axis: predict_stage(text, model) for axis, model in _MODELS.items()}
    except Exception as exc:  # request boundary: log, then report the failure
        logger.exception("TNM prediction failed")
        # Same body shape as before, but with a 500 status so clients do not
        # mistake a failed inference for a successful one.
        return JSONResponse(status_code=500, content={"error": str(exc)})
    return {"input": text, "TNM_prediction": prediction}