AshenH committed on
Commit
2dcd5ce
·
verified ·
1 Parent(s): 7caa980

Update tools/predict_tool.py

Browse files
Files changed (1) hide show
  1. tools/predict_tool.py +86 -16
tools/predict_tool.py CHANGED
@@ -1,32 +1,102 @@
 
1
  import os
 
2
  import pandas as pd
3
  import joblib
 
 
4
  from huggingface_hub import hf_hub_download
5
  from utils.config import AppConfig
6
  from utils.tracing import Tracer
7
 
 
class PredictTool:
    """
    Score a DataFrame with a model artifact stored in a Hugging Face repo.

    The repo named by ``cfg.hf_model_repo`` must contain:
      - model.pkl              (loaded with joblib)
      - feature_metadata.json  (a JSON object; the keys read here are
                                "feature_order" and "prediction_column")

    Artifacts are downloaded lazily on the first call to ``run``.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        self.cfg = cfg
        self.tracer = tracer
        self._model = None
        self._feature_meta = None

    def _ensure_loaded(self):
        """Fetch and cache the model and its metadata if not already cached.

        Both artifacts are loaded into locals first and published to the
        instance together, so a failure halfway through never leaves the tool
        with a model but no metadata (which would make every later ``run``
        skip loading and fail on ``None``).
        """
        if self._model is not None and self._feature_meta is not None:
            return

        import json

        token = os.getenv("HF_TOKEN")
        repo = self.cfg.hf_model_repo

        model_path = hf_hub_download(repo_id=repo, filename="model.pkl", token=token)
        meta_path = hf_hub_download(
            repo_id=repo, filename="feature_metadata.json", token=token
        )

        # SECURITY: joblib.load unpickles the file and can execute arbitrary
        # code; only point hf_model_repo at repositories you trust.
        model = joblib.load(model_path)
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)

        self._model, self._feature_meta = model, meta

    def run(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with one added prediction column.

        Features are taken in the metadata's "feature_order" (all columns of
        ``df`` when that key is absent). If the model exposes
        ``predict_proba`` the column holds the values at index 1 of its
        output; otherwise it holds the output of ``predict``.

        Raises:
            KeyError: if ``df`` lacks any column listed in "feature_order".
        """
        self._ensure_loaded()
        meta = self._feature_meta
        model = self._model

        use_cols = meta.get("feature_order", list(df.columns))
        # Fail with an explicit list instead of pandas' generic indexing error.
        missing = [c for c in use_cols if c not in df.columns]
        if missing:
            raise KeyError(f"Missing required features for model: {missing}")
        X = df[use_cols].copy()

        if hasattr(model, "predict_proba"):
            preds = model.predict_proba(X)[:, 1]
        else:
            preds = model.predict(X)

        out = df.copy()
        out[meta.get("prediction_column", "prediction")] = preds
        self.tracer.trace_event("predict", {"rows": len(out)})
        return out
 
 
 
 
1
+ # space/tools/predict_tool.py
2
  import os
3
+ import json
4
  import pandas as pd
5
  import joblib
6
+ from typing import Optional, List
7
+
8
  from huggingface_hub import hf_hub_download
9
  from utils.config import AppConfig
10
  from utils.tracing import Tracer
11
 
12
+
class PredictTool:
    """
    Loads a sklearn-compatible tabular model artifact from a private/public
    Hugging Face repo and runs batch predictions on a DataFrame.

    Expects in the repo named by ``cfg.hf_model_repo``:
      - model.pkl
      - feature_metadata.json (optional but recommended)
        {
          "feature_order": ["col1","col2",...],
          "prediction_column": "prediction",
          "task": "classification" | "regression"
        }

    When "task" is "regression" the model's ``predict`` output is used
    directly; otherwise ``predict_proba`` / ``decision_function`` are
    preferred when the model provides them.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        self.cfg = cfg
        self.tracer = tracer
        self._model = None
        self._feature_meta = {}
        self._pred_col = "prediction"
        self._feature_order: Optional[List[str]] = None

    def _ensure_loaded(self):
        """Download and deserialize the model once, then read optional metadata.

        Metadata loading is best-effort: any failure falls back to defaults
        (all columns as features, "prediction" as the output column) but is
        logged, because a silently ignored metadata file changes which
        columns are fed to the model.
        """
        if self._model is not None:
            return

        import logging

        log = logging.getLogger(__name__)

        token = os.getenv("HF_TOKEN")  # OK if None for public repos
        repo = self.cfg.hf_model_repo

        model_path = hf_hub_download(repo_id=repo, filename="model.pkl", token=token)
        # SECURITY: joblib.load unpickles the file and can execute arbitrary
        # code; only point hf_model_repo at repositories you trust.
        model = joblib.load(model_path)

        meta = {}
        try:
            meta_path = hf_hub_download(
                repo_id=repo, filename="feature_metadata.json", token=token
            )
            with open(meta_path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except Exception as exc:
            log.warning(
                "feature_metadata.json unavailable for %s (%s); using defaults",
                repo,
                exc,
            )
        else:
            if isinstance(loaded, dict):
                meta = loaded
            elif loaded:
                log.warning(
                    "feature_metadata.json for %s is not a JSON object; ignoring it",
                    repo,
                )

        order = meta.get("feature_order")
        if order is not None and not isinstance(order, (list, tuple)):
            # A bare string would otherwise be iterated character by character.
            log.warning("Ignoring non-list feature_order in metadata for %s", repo)
            order = None

        self._feature_meta = meta
        self._pred_col = meta.get("prediction_column") or "prediction"
        self._feature_order = list(order) if order else None
        # Publish the model last so a failure above never leaves a loaded
        # model paired with stale defaults.
        self._model = model

    def _select_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the model's input frame: trained feature order, or all columns.

        Raises:
            ValueError: if a column listed in the feature order is absent.
        """
        if not self._feature_order:
            # default: use everything present
            return df.copy()
        # keep only features in the trained order, ignore extras
        missing = [c for c in self._feature_order if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required features for model: {missing}")
        return df[self._feature_order].copy()

    def _score(self, X: pd.DataFrame):
        """Produce one score per row of ``X`` from the loaded model."""
        model = self._model
        wants_scores = self._feature_meta.get("task") != "regression"

        # classification with probabilities preferred
        if wants_scores and hasattr(model, "predict_proba"):
            return model.predict_proba(X)[:, -1]
        if wants_scores and hasattr(model, "decision_function"):
            import numpy as np

            raw = np.asarray(model.decision_function(X), dtype=float)
            # Logistic sigmoid written via tanh: same value as
            # 1 / (1 + exp(-raw)) without overflow for large |raw|.
            return 0.5 * (1.0 + np.tanh(0.5 * raw))
        return model.predict(X)

    def run(self, df: Optional[pd.DataFrame]) -> pd.DataFrame:
        """
        Return a copy of ``df`` with the prediction column appended.

        If df is None or has no rows, returns an empty DataFrame without
        downloading or loading the model.

        Raises:
            ValueError: if ``df`` lacks a feature required by the metadata.
        """
        if df is None or len(df) == 0:
            return pd.DataFrame()

        self._ensure_loaded()
        X = self._select_features(df)
        preds = self._score(X)

        out = df.copy()
        out[self._pred_col] = preds
        try:
            self.tracer.trace_event("predict", {"rows": len(out)})
        except Exception:
            # Tracing is best-effort and must never fail a prediction, but
            # leave a breadcrumb instead of discarding the error entirely.
            import logging

            logging.getLogger(__name__).debug("trace_event failed", exc_info=True)
        return out