# TNM / app.py
# Uploaded by MohamedTry ("Upload 3 files", commit c22cacc, verified), 2.49 kB
import os
from concurrent.futures import ThreadPoolExecutor

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
from pydantic import BaseModel
# === Config ===
# Hugging Face access token taken from the environment (store it as a Secret
# in your Space). Stays None when the variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Hub repo IDs of the T, N and M classifiers. Each one can be overridden by
# the environment variable of the same name.
MODEL_T = os.environ.get("MODEL_T", "jkefeli/CancerStage_Classifier_T")
MODEL_N = os.environ.get("MODEL_N", "jkefeli/CancerStage_Classifier_N")
MODEL_M = os.environ.get("MODEL_M", "jkefeli/CancerStage_Classifier_M")
# Initialize FastAPI
app = FastAPI(title="TNM Endpoint", version="1.0.0")
# CORS: allowed origins come from CORS_ALLOW_ORIGINS (comma-separated list).
# The default "*" allows all origins; set an explicit list for production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # With a wildcard origin plus credentials, Starlette echoes the caller's
    # Origin header back, which would let *any* site send cookie-bearing
    # requests. These endpoints use no cookies/auth, so credentials are only
    # enabled when an explicit origin allow-list is configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# One Hugging Face Inference client per TNM axis. Inference runs remotely, so
# this Space does no heavy compute. predict_tnm relies on the "T"/"N"/"M" keys.
clients = {
    axis: InferenceClient(model=repo_id, token=HF_TOKEN)
    for axis, repo_id in (("T", MODEL_T), ("N", MODEL_N), ("M", MODEL_M))
}
class ReportInput(BaseModel):
    """Request body for ``/predict_tnm``: the report text to classify."""

    # Raw report text; the endpoint strips and truncates it before inference.
    text: str
@app.get("/healthz")
def healthz():
    """Liveness check that also reports the model repo configured per axis."""
    configured = dict(T=MODEL_T, N=MODEL_N, M=MODEL_M)
    return dict(status="ok", models=configured)
def _label_score(element) -> dict:
    """Normalize one text-classification output element to ``{"label", "score"}``.

    Accepts both plain dicts and objects exposing ``.label`` / ``.score``
    attributes, so the result does not depend on which shape the installed
    huggingface_hub version returns. A missing label becomes ``"UNKNOWN"`` and
    a missing score becomes ``0.0``.

    Raises:
        TypeError, ValueError: if the score cannot be converted to ``float``.
    """
    if isinstance(element, dict) and not hasattr(element, "label"):
        label = element.get("label")
        score = element.get("score")
    else:
        label = getattr(element, "label", None)
        score = getattr(element, "score", None)
    return {
        "label": "UNKNOWN" if label is None else label,
        "score": 0.0 if score is None else float(score),
    }


def _classify(client: InferenceClient, text: str):
    """Run remote text classification and return the highest-scoring label.

    Args:
        client: Inference client bound to one classifier model.
        text: Report text to classify.

    Returns:
        ``{"label": ..., "score": float, "raw": list[dict]}``, where ``raw``
        holds every returned label/score pair normalized to plain dicts.

    Raises:
        HTTPException: 502 if the remote call fails, returns no predictions,
            or returns elements whose score is not numeric.
    """
    try:
        # Do not pass ``wait_for_model``: text_classification() does not accept
        # it, and the resulting TypeError turned every request into a 502.
        outputs = client.text_classification(text)
    except Exception as e:  # any failure of the remote call is an upstream error
        raise HTTPException(status_code=502, detail=f"Inference error: {e}") from e
    if not outputs:
        raise HTTPException(
            status_code=502, detail="Inference error: Empty response from model"
        )
    try:
        preds = [_label_score(element) for element in outputs]
    except (TypeError, ValueError) as e:  # malformed upstream payload
        raise HTTPException(status_code=502, detail=f"Inference error: {e}") from e
    best = max(preds, key=lambda p: p["score"])
    return {"label": best["label"], "score": best["score"], "raw": preds}
# Hard cap on characters forwarded to the models, to avoid hitting API limits
# with very long reports.
_MAX_INPUT_CHARS = 20000


@app.post("/predict_tnm")
def predict_tnm(input: ReportInput):
    """Classify a report's T, N and M categories and combine them.

    The text is stripped and truncated to ``_MAX_INPUT_CHARS`` characters.
    It is then sent to all three remote classifiers concurrently. The calls
    are independent remote round-trips, so total latency is that of the
    slowest call rather than the sum of all three.

    Args:
        input: Request body holding the report ``text``.

    Returns:
        A dict with:
        - ``input_chars``: length of the text actually classified.
        - ``tnm``: the per-axis ``_classify`` result.
        - ``tnm_string``: the T, N and M labels joined by spaces.
        - ``meta``: the configured model IDs.

    Raises:
        HTTPException: 400 if the text is empty after stripping; 502 (from
            ``_classify``) if any remote inference call fails.
    """
    text = (input.text or "").strip()
    if not text:
        raise HTTPException(status_code=400, detail="Empty 'text'")
    text = text[:_MAX_INPUT_CHARS]

    with ThreadPoolExecutor(max_workers=len(clients) or 1) as pool:
        futures = {
            key: pool.submit(_classify, client, text)
            for key, client in clients.items()
        }
        # .result() re-raises any HTTPException from _classify unchanged.
        preds = {key: future.result() for key, future in futures.items()}

    t = preds["T"]["label"]
    n = preds["N"]["label"]
    m = preds["M"]["label"]
    return {
        "input_chars": len(text),
        "tnm": preds,
        "tnm_string": f"{t} {n} {m}",
        "meta": {"models": {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}},
    }