File size: 2,434 Bytes
aed2def 9b2b28b aed2def 9b2b28b aed2def 9b2b28b aed2def 9b2b28b aed2def 9b2b28b aed2def 9b2b28b aed2def 1470e7a aed2def 9b2b28b aed2def 9b2b28b aed2def 1470e7a aed2def 9b2b28b aed2def 1470e7a aed2def 9b2b28b aed2def 9b2b28b aed2def 9b2b28b aed2def 1470e7a aed2def 9b2b28b aed2def 9b2b28b aed2def 9b2b28b aed2def |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
# space/tools/explain_tool.py
import os
import io
import json
import base64
from typing import Dict, Optional
import shap
import pandas as pd
import matplotlib.pyplot as plt
import joblib
from huggingface_hub import hf_hub_download
from utils.config import AppConfig
from utils.tracing import Tracer
class ExplainTool:
    """
    Generates global SHAP visualizations for a sample of rows (CPU-friendly).

    The model (``model.pkl``) and the optional ``feature_metadata.json`` are
    downloaded from the Hugging Face repo named by ``cfg.hf_model_repo`` the
    first time a non-empty frame is explained, then cached on the instance.
    """

    # Upper bound on the number of rows handed to the SHAP explainer.
    MAX_SAMPLE_ROWS = 500
    # Fixed seed so the same input frame always yields the same sample.
    SAMPLE_RANDOM_STATE = 42

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Store the configuration and tracer; nothing is downloaded here."""
        self.cfg = cfg
        self.tracer = tracer
        self._model = None
        self._feature_order = None

    def _ensure_model(self):
        """Download and load the model (and feature order) once per instance.

        A failure to fetch or load the model itself propagates to the caller.
        The feature metadata is optional: any failure while fetching or
        parsing it leaves ``_feature_order`` as ``None``.
        """
        if self._model is not None:
            return
        token = os.getenv("HF_TOKEN")
        repo = self.cfg.hf_model_repo
        model_path = hf_hub_download(repo_id=repo, filename="model.pkl", token=token)
        # SECURITY NOTE: joblib.load unpickles the file and can execute
        # arbitrary code; only point cfg.hf_model_repo at a trusted repository.
        self._model = joblib.load(model_path)
        try:
            meta_path = hf_hub_download(
                repo_id=repo, filename="feature_metadata.json", token=token
            )
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f) or {}
            self._feature_order = meta.get("feature_order")
        except Exception:
            # Deliberate best-effort: without metadata every column of the
            # input frame is used as-is.
            self._feature_order = None

    @staticmethod
    def _to_data_uri(fig) -> str:
        """Render ``fig`` as a PNG and return it as a base64 ``data:`` URI.

        The figure is always closed, even when rendering fails, so repeated
        calls do not accumulate open matplotlib figures.
        """
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
        finally:
            plt.close(fig)
        return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    def _render(self, plot_fn, shap_values) -> str:
        """Draw one SHAP plot on a fresh figure and return it as a data URI."""
        # The new figure becomes the current one before the plot call; it is
        # the figure that gets saved below.
        fig = plt.figure()
        try:
            plot_fn(shap_values, show=False)
            return self._to_data_uri(fig)
        finally:
            # Closing an already-closed figure is harmless; this covers the
            # case where the plot call itself raises.
            plt.close(fig)

    def run(self, df: Optional[pd.DataFrame]) -> Dict[str, str]:
        """Build the global SHAP bar and beeswarm plots for ``df``.

        Args:
            df: Rows to explain. When feature metadata is available, only the
                listed columns present in ``df`` are used, in metadata order.

        Returns:
            ``{}`` when ``df`` is ``None`` or has no rows; otherwise a dict
            with keys ``"global_bar"`` and ``"beeswarm"`` holding PNG data
            URIs.
        """
        # Check for empty input first so that a no-op call never triggers a
        # model download.
        if df is None or len(df) == 0:
            return {}
        self._ensure_model()

        if self._feature_order:
            cols = [c for c in self._feature_order if c in df.columns]
            X = df[cols].copy()
        else:
            X = df.copy()

        # Cap the rows passed to SHAP; smaller frames are used whole.
        n = min(len(X), self.MAX_SAMPLE_ROWS)
        if len(X) > n:
            sample = X.sample(n, random_state=self.SAMPLE_RANDOM_STATE)
        else:
            sample = X

        explainer = shap.Explainer(self._model, sample)
        sv = explainer(sample)

        bar_uri = self._render(shap.plots.bar, sv)
        bee_uri = self._render(shap.plots.beeswarm, sv)

        try:
            self.tracer.trace_event("explain", {"rows": int(n)})
        except Exception:
            # Tracing is best-effort and must never fail the explanation.
            pass
        return {"global_bar": bar_uri, "beeswarm": bee_uri}
|