MohamedTry commited on
Commit
3957394
·
verified ·
1 Parent(s): a766287

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -61
app.py CHANGED
@@ -1,75 +1,56 @@
1
-
2
- import os
3
- from fastapi import FastAPI, HTTPException
4
- from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel
6
  from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
7
 
8
- # === Config ===
9
- HF_TOKEN = os.getenv("HF_TOKEN") # add this as a Secret in your Space
10
  MODEL_T = os.getenv("MODEL_T", "jkefeli/CancerStage_Classifier_T")
11
  MODEL_N = os.getenv("MODEL_N", "jkefeli/CancerStage_Classifier_N")
12
  MODEL_M = os.getenv("MODEL_M", "jkefeli/CancerStage_Classifier_M")
13
 
14
- # Initialize FastAPI
15
- app = FastAPI(title="TNM Endpoint", version="1.0.0")
16
 
17
- # CORS (optional): allow all origins by default; tighten for production
18
- app.add_middleware(
19
- CORSMiddleware,
20
- allow_origins=["*"],
21
- allow_credentials=True,
22
- allow_methods=["*"],
23
- allow_headers=["*"],
24
- )
25
-
26
- # Init Hugging Face Inference clients (remote inference, no heavy compute on your Space)
27
- clients = {
28
- "T": InferenceClient(model=MODEL_T, token=HF_TOKEN),
29
- "N": InferenceClient(model=MODEL_N, token=HF_TOKEN),
30
- "M": InferenceClient(model=MODEL_M, token=HF_TOKEN),
31
- }
32
-
33
- class ReportInput(BaseModel):
34
  text: str
35
 
36
- @app.get("/healthz")
37
- def healthz():
38
- return {"status": "ok", "models": {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}}
39
-
40
- def _classify(client: InferenceClient, text: str):
41
- # Uses HF Inference API task: text-classification
42
- # Returns best label with score
43
- try:
44
- outputs = client.text_classification(text, wait_for_model=True)
45
- if not outputs:
46
- raise ValueError("Empty response from model")
47
- best = max(outputs, key=lambda x: x.get("score", 0))
48
- return {"label": best.get("label", "UNKNOWN"), "score": float(best.get("score", 0.0)), "raw": outputs}
49
- except Exception as e:
50
- raise HTTPException(status_code=502, detail=f"Inference error: {e}")
51
 
52
  @app.post("/predict_tnm")
53
- def predict_tnm(input: ReportInput):
54
- text = (input.text or "").strip()
55
- if not text:
56
- raise HTTPException(status_code=400, detail="Empty 'text'")
57
- # Optionally hard-truncate very long inputs to avoid API limits
58
- if len(text) > 20000:
59
- text = text[:20000]
60
-
61
- preds = {}
62
- for key, client in clients.items():
63
- preds[key] = _classify(client, text)
64
 
65
- t = preds["T"]["label"]
66
- n = preds["N"]["label"]
67
- m = preds["M"]["label"]
68
- tnm_string = f"{t} {n} {m}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- return {
71
- "input_chars": len(text),
72
- "tnm": preds,
73
- "tnm_string": tnm_string,
74
- "meta": {"models": {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}}
75
- }
 
1
+ from fastapi import FastAPI, Body
 
 
 
2
  from pydantic import BaseModel
3
  from huggingface_hub import InferenceClient
4
+ import os
5
+
6
+ # FastAPI app
7
+ app = FastAPI(title="TNM Cancer Stage Endpoint", version="1.0")
8
+
9
+ # HF token from environment (set in Space secrets)
10
+ HF_TOKEN = os.getenv("HF_TOKEN")
11
 
12
+ # Models (يمكنك تعديلها من Variables في Hugging Face Space)
 
13
  MODEL_T = os.getenv("MODEL_T", "jkefeli/CancerStage_Classifier_T")
14
  MODEL_N = os.getenv("MODEL_N", "jkefeli/CancerStage_Classifier_N")
15
  MODEL_M = os.getenv("MODEL_M", "jkefeli/CancerStage_Classifier_M")
16
 
17
+ # Hugging Face client
18
+ client = InferenceClient(token=HF_TOKEN)
19
 
20
+ class Report(BaseModel):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  text: str
22
 
23
+ @app.get("/")
24
+ def health_check():
25
+ return {
26
+ "status": "running",
27
+ "models": {"T": MODEL_T, "N": MODEL_N, "M": MODEL_M}
28
+ }
 
 
 
 
 
 
 
 
 
29
 
30
  @app.post("/predict_tnm")
31
+ def predict_tnm(report: Report = Body(...)):
32
+ text = report.text
 
 
 
 
 
 
 
 
 
33
 
34
+ try:
35
+ # Call each model
36
+ pred_T = client.text_classification(text, model=MODEL_T)
37
+ pred_N = client.text_classification(text, model=MODEL_N)
38
+ pred_M = client.text_classification(text, model=MODEL_M)
39
+
40
+ # Extract top labels
41
+ t_label = pred_T[0]["label"] if pred_T else None
42
+ n_label = pred_N[0]["label"] if pred_N else None
43
+ m_label = pred_M[0]["label"] if pred_M else None
44
+
45
+ return {
46
+ "input_chars": len(text),
47
+ "tnm": {
48
+ "T": pred_T[0] if pred_T else None,
49
+ "N": pred_N[0] if pred_N else None,
50
+ "M": pred_M[0] if pred_M else None,
51
+ },
52
+ "tnm_string": f"{t_label} {n_label} {m_label}"
53
+ }
54
 
55
+ except Exception as e:
56
+ return {"error": str(e)}