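"""FastAPI endpoint that predicts TNM cancer staging (T, N, M) from pathology
report text. Each component is classified by a separate Hugging Face model
through the remote Inference API, so the Space does no heavy compute itself."""
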
import os
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from huggingface_hub import InferenceClient

# === Config ===
HF_TOKEN = os.getenv("HF_TOKEN")  # add this as a Secret in your Space
MODEL_T = os.getenv("MODEL_T", "jkefeli/CancerStage_Classifier_T")
MODEL_N = os.getenv("MODEL_N", "jkefeli/CancerStage_Classifier_N")
MODEL_M = os.getenv("MODEL_M", "jkefeli/CancerStage_Classifier_M")
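# For local testing you can export the same variables before launching the
# app, e.g. (placeholder token value):
#   export HF_TOKEN=hf_xxxxxxxx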

# Initialize FastAPI
app = FastAPI(title="TNM Endpoint", version="1.0.0")

# CORS (optional): allow all origins by default; tighten for production (see the sketch below)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
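
# Note that browsers reject a wildcard origin when credentials are allowed, so
# a production deployment should pin the origin list. A tightened setup might
# look like this (hypothetical frontend origin; adjust to your actual UI host):
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["https://your-frontend.example.com"],
#     allow_credentials=True,
#     allow_methods=["GET", "POST"],
#     allow_headers=["Content-Type"],
# )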

# Init Hugging Face Inference clients (remote inference, no heavy compute on your Space)
clients = {
    "T": InferenceClient(model=MODEL_T, token=HF_TOKEN),
    "N": InferenceClient(model=MODEL_N, token=HF_TOKEN),
    "M": InferenceClient(model=MODEL_M, token=HF_TOKEN),
}

class ReportInput(BaseModel):
    text: str

@app.get("/healthz")
def healthz():
    return {"status": "ok", "models": {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}}

def _classify(client: InferenceClient, text: str):
    # Calls the HF Inference API text-classification task and returns the best
    # label with its score. Recent huggingface_hub versions return dataclass
    # elements (.label / .score) rather than dicts, and text_classification()
    # takes no wait_for_model argument, so normalize defensively here.
    try:
        outputs = client.text_classification(text)
        if not outputs:
            raise ValueError("Empty response from model")
        results = [
            o if isinstance(o, dict)
            else {"label": getattr(o, "label", "UNKNOWN"), "score": getattr(o, "score", 0.0)}
            for o in outputs
        ]
        best = max(results, key=lambda x: x.get("score", 0.0))
        return {"label": best.get("label", "UNKNOWN"), "score": float(best.get("score", 0.0)), "raw": results}
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Inference error: {e}") from e

@app.post("/predict_tnm")
def predict_tnm(payload: ReportInput):
    text = (payload.text or "").strip()
    if not text:
        raise HTTPException(status_code=400, detail="Empty 'text'")
    # Optionally hard-truncate very long inputs to avoid API limits
    if len(text) > 20000:
        text = text[:20000]

    preds = {}
    for key, client in clients.items():
        preds[key] = _classify(client, text)

    t = preds["T"]["label"]
    n = preds["N"]["label"]
    m = preds["M"]["label"]
    tnm_string = f"{t} {n} {m}"

    return {
        "input_chars": len(text),
        "tnm": preds,
        "tnm_string": tnm_string,
        "meta": {"models": {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}}
    }
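
# Example request (assuming the default Spaces port 7860):
#   curl -X POST http://localhost:7860/predict_tnm \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Pathology report text ..."}'

# Local dev entry point; on Hugging Face Spaces the server is usually launched
# by the Space runtime / Dockerfile instead. Port 7860 is the Spaces default.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)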