# TNM / app.py
# Author: MohamedTry
# Last commit: e93b18a "Update app.py" (verified)
import logging

import torch
from fastapi import FastAPI, Body
from pydantic import BaseModel
from scipy.special import softmax
from transformers import AutoTokenizer, BigBirdForSequenceClassification
# The FastAPI application; title and version are passed straight to FastAPI.
app = FastAPI(title="TNM Endpoint", version="1.0")

# Hugging Face Hub checkpoint ids, one sequence classifier per TNM axis.
MODEL_T = "jkefeli/CancerStage_Classifier_T"
MODEL_N = "jkefeli/CancerStage_Classifier_N"
MODEL_M = "jkefeli/CancerStage_Classifier_M"

# A single tokenizer shared by all three classifiers, created once at import.
# NOTE(review): assumes the three checkpoints share Clinical-BigBird's
# vocabulary — confirm against the model cards.
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-BigBird")

# Each classifier is loaded exactly once at import time (CPU mode), in T, N, M order.
model_T, model_N, model_M = (
    BigBirdForSequenceClassification.from_pretrained(checkpoint)
    for checkpoint in (MODEL_T, MODEL_N, MODEL_M)
)
class Report(BaseModel):
    # Request body for POST /predict_tnm.
    # Comments (not a docstring) are used here so the generated JSON schema is unchanged.
    # Free-text report; the same string is passed to every staging classifier.
    text: str
def predict_stage(text, model, *, max_length=2048, tok=None):
    """Classify one report with a single staging model.

    Args:
        text: Free-text report to classify.
        model: A sequence-classification model, called as ``model(**inputs)``,
            whose output exposes a ``.logits`` tensor.
        max_length: Maximum token count kept after truncation. Defaults to
            2048, the limit this function has always used.
        tok: Tokenizer to apply to ``text``. Defaults to the module-level
            ``tokenizer``.

    Returns:
        ``{"class": int, "probs": list[list[float]]}``: the argmax class index
        and the softmax of the logits. ``probs`` keeps the batch axis, so for
        this single-text call it is a one-element list of per-class
        probabilities.
    """
    if tok is None:
        tok = tokenizer
    inputs = tok(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    probs = softmax(outputs.logits.numpy(), axis=1)
    # Row 0 is the single input text; int() turns the numpy scalar into JSON-friendly data.
    pred_class = probs.argmax(axis=1)[0]
    return {"class": int(pred_class), "probs": probs.tolist()}
@app.get("/")
def health_check():
    # Liveness probe: reports that the service is running and which Hub
    # checkpoint serves each TNM axis. Comments are used instead of a
    # docstring so the OpenAPI description stays unchanged.
    checkpoints = dict(T=MODEL_T, N=MODEL_N, M=MODEL_M)
    return dict(status="running", models=checkpoints)
@app.post("/predict_tnm")
def predict_tnm(report: Report = Body(...)):
    """Predict the T, N and M stage classes for a free-text report.

    Each of the three staging classifiers is run on the same text; the
    response echoes the input and gives each axis's class index and
    class probabilities.
    \f
    Args:
        report: Parsed request body; only ``report.text`` is used.

    Returns:
        On success, ``{"input": text, "TNM_prediction": {"T": ..., "N": ..., "M": ...}}``,
        where each axis value is the dict returned by ``predict_stage``.
        On any failure, ``{"error": message}``. This is still returned as a
        normal (HTTP 200) response, which existing clients rely on.
    """
    text = report.text
    try:
        # Same evaluation order as before (T, then N, then M); the first failure stops the rest.
        prediction = {
            axis: predict_stage(text, model)
            for axis, model in (("T", model_T), ("N", model_N), ("M", model_M))
        }
    except Exception as e:
        # Top-level boundary: keep the {"error": ...} contract, but record the
        # full traceback server-side instead of silently discarding it.
        logging.getLogger(__name__).exception("TNM prediction failed")
        return {"error": str(e)}
    return {"input": text, "TNM_prediction": prediction}