# ALM_LLM / tools / explain_tool.py
# AshenH's picture
# Update tools/explain_tool.py
# 1470e7a verified
# raw
# history blame
# 2.43 kB
# space/tools/explain_tool.py
import os
import io
import json
import base64
from typing import Dict, Optional
import shap
import pandas as pd
import matplotlib.pyplot as plt
import joblib
from huggingface_hub import hf_hub_download
from utils.config import AppConfig
from utils.tracing import Tracer
class ExplainTool:
    """Generate global SHAP visualizations for a sample of rows (CPU-friendly).

    The model (``model.pkl``) and, when available, feature metadata
    (``feature_metadata.json``) are downloaded from the Hugging Face repo
    named by ``cfg.hf_model_repo`` on first use and kept on the instance.
    """

    # Upper bound on the number of rows handed to the SHAP explainer.
    _MAX_SAMPLE_ROWS = 500
    # Fixed seed so repeated runs down-sample to the same rows.
    _SAMPLE_SEED = 42

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Store collaborators; the model itself is loaded lazily.

        Args:
            cfg: Application config; ``cfg.hf_model_repo`` is read when the
                model is first needed.
            tracer: Receives a best-effort ``"explain"`` event after each
                successful run.
        """
        self.cfg = cfg
        self.tracer = tracer
        self._model = None
        self._feature_order = None

    def _ensure_model(self):
        """Download and load the model (and feature order) once per instance."""
        if self._model is not None:
            return
        token = os.getenv("HF_TOKEN")
        repo = self.cfg.hf_model_repo
        model_path = hf_hub_download(repo_id=repo, filename="model.pkl", token=token)
        # SECURITY NOTE(review): joblib.load unpickles the file and can execute
        # arbitrary code; only point ``hf_model_repo`` at a trusted repository.
        self._model = joblib.load(model_path)
        # Feature metadata is optional: any failure to fetch or parse it means
        # "use the frame's columns as they are".
        try:
            meta_path = hf_hub_download(
                repo_id=repo, filename="feature_metadata.json", token=token
            )
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f) or {}
            self._feature_order = meta.get("feature_order")
        except Exception:
            self._feature_order = None

    @staticmethod
    def _to_data_uri(fig) -> str:
        """Render *fig* to PNG and return it as a base64 ``data:`` URI.

        The figure is always closed, even if rendering fails, so figures do
        not accumulate in pyplot's global registry.
        """
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
        finally:
            plt.close(fig)
        return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    def _render_plot(self, plot_fn, shap_values) -> str:
        """Draw one SHAP plot on a fresh figure and return it as a data URI."""
        fig = plt.figure()
        try:
            plot_fn(shap_values, show=False)
        except Exception:
            # Do not leak the figure when the plot call itself fails.
            plt.close(fig)
            raise
        return self._to_data_uri(fig)

    def run(self, df: Optional[pd.DataFrame]) -> Dict[str, str]:
        """Compute SHAP values for (a sample of) *df* and plot them.

        Args:
            df: Rows to explain. ``None`` or an empty frame yields ``{}``.

        Returns:
            ``{"global_bar": <data URI>, "beeswarm": <data URI>}``, or ``{}``
            when there is nothing to explain (no rows, or none of the model's
            known feature columns are present in *df*).
        """
        # Check for empty input first so it never triggers a model download.
        if df is None or len(df) == 0:
            return {}

        self._ensure_model()

        if self._feature_order:
            cols = [c for c in self._feature_order if c in df.columns]
            if not cols:
                # No overlap with the model's features: nothing to explain.
                return {}
            X = df[cols].copy()
        else:
            X = df.copy()

        # Cap the explained rows to keep the computation cheap.
        n = min(len(X), self._MAX_SAMPLE_ROWS)
        sample = X.sample(n, random_state=self._SAMPLE_SEED) if len(X) > n else X

        explainer = shap.Explainer(self._model, sample)
        sv = explainer(sample)

        bar_uri = self._render_plot(shap.plots.bar, sv)
        bee_uri = self._render_plot(shap.plots.beeswarm, sv)

        # Tracing is best-effort: a tracer failure must not discard the plots.
        try:
            self.tracer.trace_event("explain", {"rows": int(n)})
        except Exception:
            pass

        return {"global_bar": bar_uri, "beeswarm": bee_uri}