AshenH commited on
Commit
1470e7a
·
verified ·
1 Parent(s): 32817e1

Update tools/explain_tool.py

Browse files
Files changed (1) hide show
  1. tools/explain_tool.py +5 -27
tools/explain_tool.py CHANGED
@@ -17,8 +17,7 @@ from utils.tracing import Tracer
17
 
18
  class ExplainTool:
19
  """
20
- Generates lightweight global SHAP visualizations (bar + beeswarm) for a sample
21
- of the current DataFrame. Designed to run on CPU in HF Spaces.
22
  """
23
  def __init__(self, cfg: AppConfig, tracer: Tracer):
24
  self.cfg = cfg
@@ -32,20 +31,11 @@ class ExplainTool:
32
  token = os.getenv("HF_TOKEN")
33
  repo = self.cfg.hf_model_repo
34
 
35
- model_path = hf_hub_download(
36
- repo_id=repo,
37
- filename="model.pkl",
38
- token=token
39
- )
40
  self._model = joblib.load(model_path)
41
 
42
- # read optional feature metadata to keep column order consistent
43
  try:
44
- meta_path = hf_hub_download(
45
- repo_id=repo,
46
- filename="feature_metadata.json",
47
- token=token
48
- )
49
  with open(meta_path, "r", encoding="utf-8") as f:
50
  meta = json.load(f) or {}
51
  self._feature_order = meta.get("feature_order")
@@ -61,38 +51,26 @@ class ExplainTool:
61
  return "data:image/png;base64," + base64.b64encode(buf.read()).decode()
62
 
63
  def run(self, df: Optional[pd.DataFrame]) -> Dict[str, str]:
64
- """
65
- Returns dict of {plot_name: data_uri_png}. If df is None/empty, returns {}.
66
- """
67
  self._ensure_model()
68
  if df is None or len(df) == 0:
69
  return {}
70
 
71
- # Select & sample features
72
  if self._feature_order:
73
- missing = [c for c in self._feature_order if c not in df.columns]
74
- if missing:
75
- # best effort: intersect
76
- X = df[[c for c in self._feature_order if c in df.columns]].copy()
77
- else:
78
- X = df[self._feature_order].copy()
79
  else:
80
  X = df.copy()
81
 
82
- # Small sample for speed
83
  n = min(len(X), 500)
84
  sample = X.sample(n, random_state=42) if len(X) > n else X
85
 
86
- # Build explainer and compute SHAP values
87
  explainer = shap.Explainer(self._model, sample)
88
  sv = explainer(sample)
89
 
90
- # --- Global bar plot ---
91
  fig_bar = plt.figure()
92
  shap.plots.bar(sv, show=False)
93
  bar_uri = self._to_data_uri(fig_bar)
94
 
95
- # --- Beeswarm plot ---
96
  fig_bee = plt.figure()
97
  shap.plots.beeswarm(sv, show=False)
98
  bee_uri = self._to_data_uri(fig_bee)
 
17
 
18
  class ExplainTool:
19
  """
20
+ Generates global SHAP visualizations for a sample of rows (CPU-friendly).
 
21
  """
22
  def __init__(self, cfg: AppConfig, tracer: Tracer):
23
  self.cfg = cfg
 
31
  token = os.getenv("HF_TOKEN")
32
  repo = self.cfg.hf_model_repo
33
 
34
+ model_path = hf_hub_download(repo_id=repo, filename="model.pkl", token=token)
 
 
 
 
35
  self._model = joblib.load(model_path)
36
 
 
37
  try:
38
+ meta_path = hf_hub_download(repo_id=repo, filename="feature_metadata.json", token=token)
 
 
 
 
39
  with open(meta_path, "r", encoding="utf-8") as f:
40
  meta = json.load(f) or {}
41
  self._feature_order = meta.get("feature_order")
 
51
  return "data:image/png;base64," + base64.b64encode(buf.read()).decode()
52
 
53
  def run(self, df: Optional[pd.DataFrame]) -> Dict[str, str]:
 
 
 
54
  self._ensure_model()
55
  if df is None or len(df) == 0:
56
  return {}
57
 
 
58
  if self._feature_order:
59
+ cols = [c for c in self._feature_order if c in df.columns]
60
+ X = df[cols].copy()
 
 
 
 
61
  else:
62
  X = df.copy()
63
 
 
64
  n = min(len(X), 500)
65
  sample = X.sample(n, random_state=42) if len(X) > n else X
66
 
 
67
  explainer = shap.Explainer(self._model, sample)
68
  sv = explainer(sample)
69
 
 
70
  fig_bar = plt.figure()
71
  shap.plots.bar(sv, show=False)
72
  bar_uri = self._to_data_uri(fig_bar)
73
 
 
74
  fig_bee = plt.figure()
75
  shap.plots.beeswarm(sv, show=False)
76
  bee_uri = self._to_data_uri(fig_bee)