```python
from fastapi import FastAPI, Body
from pydantic import BaseModel
from transformers import AutoTokenizer, BigBirdForSequenceClassification
from scipy.special import softmax
import torch

# Initialize FastAPI
app = FastAPI(title="TNM Endpoint", version="1.0")

# Models (TNM) from Hugging Face
MODEL_T = "jkefeli/CancerStage_Classifier_T"
MODEL_N = "jkefeli/CancerStage_Classifier_N"
MODEL_M = "jkefeli/CancerStage_Classifier_M"

# Load tokenizer once
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-BigBird")

# Load models once (CPU mode)
model_T = BigBirdForSequenceClassification.from_pretrained(MODEL_T)
model_N = BigBirdForSequenceClassification.from_pretrained(MODEL_N)
model_M = BigBirdForSequenceClassification.from_pretrained(MODEL_M)


class Report(BaseModel):
    text: str


def predict_stage(text, model):
    # Tokenize the report (truncated to 2048 tokens) and run one forward pass
    # without tracking gradients, then convert logits to class probabilities
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=2048)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = softmax(outputs.logits.numpy(), axis=1)
    pred_class = probs.argmax(axis=1)[0]
    return {"class": int(pred_class), "probs": probs.tolist()}


@app.get("/")  # route decorator missing from the original listing; path assumed
def health_check():
    return {"status": "running", "models": {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}}


@app.post("/predict_tnm")  # route decorator missing from the original listing; path assumed
def predict_tnm(report: Report = Body(...)):
    text = report.text
    try:
        # Run the same report through the T, N, and M classifiers
        t_result = predict_stage(text, model_T)
        n_result = predict_stage(text, model_N)
        m_result = predict_stage(text, model_M)
        return {
            "input": text,
            "TNM_prediction": {
                "T": t_result,
                "N": n_result,
                "M": m_result
            }
        }
    except Exception as e:
        return {"error": str(e)}
```