# FastAPI Space: remote TNM cancer-stage classification endpoints.
| import os | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from huggingface_hub import InferenceClient | |
# === Config ===
# HF_TOKEN must be added as a Secret in the Space settings.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Model repos for each TNM axis; each can be overridden via the environment.
MODEL_T = os.environ.get("MODEL_T", "jkefeli/CancerStage_Classifier_T")
MODEL_N = os.environ.get("MODEL_N", "jkefeli/CancerStage_Classifier_N")
MODEL_M = os.environ.get("MODEL_M", "jkefeli/CancerStage_Classifier_M")
# FastAPI application instance.
app = FastAPI(title="TNM Endpoint", version="1.0.0")

# CORS: wide open by default; tighten the origin list for production.
_cors_kwargs = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_kwargs)
# One remote InferenceClient per TNM axis. Inference runs on Hugging Face's
# servers, so the Space itself needs no heavy compute.
clients = {
    axis: InferenceClient(model=repo_id, token=HF_TOKEN)
    for axis, repo_id in (("T", MODEL_T), ("N", MODEL_N), ("M", MODEL_M))
}
class ReportInput(BaseModel):
    """Request body: raw pathology-report text to classify."""

    text: str
@app.get("/healthz")
def healthz():
    """Health probe: report service status and the configured model repos.

    NOTE(review): the original had no route decorator, so this handler was
    never registered with the app; exposed here as GET /healthz.
    """
    return {"status": "ok", "models": {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}}
| def _classify(client: InferenceClient, text: str): | |
| # Uses HF Inference API task: text-classification | |
| # Returns best label with score | |
| try: | |
| outputs = client.text_classification(text, wait_for_model=True) | |
| if not outputs: | |
| raise ValueError("Empty response from model") | |
| best = max(outputs, key=lambda x: x.get("score", 0)) | |
| return {"label": best.get("label", "UNKNOWN"), "score": float(best.get("score", 0.0)), "raw": outputs} | |
| except Exception as e: | |
| raise HTTPException(status_code=502, detail=f"Inference error: {e}") | |
@app.post("/predict")
def predict_tnm(input: ReportInput):
    """Classify a pathology report into T, N and M stage labels.

    NOTE(review): the original had no route decorator, so this handler was
    never registered with the app; exposed here as POST /predict.

    Args:
        input: Request body carrying the raw report text.

    Returns:
        dict with per-axis predictions (``tnm``), a combined "T N M" string
        (``tnm_string``), the number of characters actually classified
        (``input_chars``), and the model repos used (``meta``).

    Raises:
        HTTPException: 400 for empty input; 502 (via ``_classify``) when
        remote inference fails.
    """
    text = (input.text or "").strip()
    if not text:
        raise HTTPException(status_code=400, detail="Empty 'text'")
    # Hard-truncate very long inputs to stay under Inference API limits.
    if len(text) > 20000:
        text = text[:20000]
    preds = {axis: _classify(client, text) for axis, client in clients.items()}
    tnm_string = f"{preds['T']['label']} {preds['N']['label']} {preds['M']['label']}"
    return {
        "input_chars": len(text),
        "tnm": preds,
        "tnm_string": tnm_string,
        "meta": {"models": {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}},
    }