File size: 2,434 Bytes
aed2def
9b2b28b
 
aed2def
9b2b28b
aed2def
 
 
9b2b28b
aed2def
 
9b2b28b
aed2def
9b2b28b
 
 
aed2def
9b2b28b
aed2def
1470e7a
aed2def
9b2b28b
 
 
 
aed2def
9b2b28b
 
aed2def
 
 
 
 
1470e7a
aed2def
9b2b28b
aed2def
1470e7a
aed2def
 
 
 
 
 
 
 
9b2b28b
aed2def
 
9b2b28b
 
 
aed2def
9b2b28b
aed2def
 
 
 
1470e7a
 
aed2def
 
 
 
 
 
 
 
 
 
 
 
9b2b28b
aed2def
 
 
9b2b28b
aed2def
 
 
 
9b2b28b
aed2def
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# space/tools/explain_tool.py
import os
import io
import json
import base64
from typing import Dict, Optional

import shap
import pandas as pd
import matplotlib.pyplot as plt
import joblib
from huggingface_hub import hf_hub_download

from utils.config import AppConfig
from utils.tracing import Tracer


class ExplainTool:
    """
    Generates global SHAP visualizations for a sample of rows (CPU-friendly).

    The model (``model.pkl``) and the optional ``feature_metadata.json`` are
    downloaded from ``cfg.hf_model_repo`` on first use and cached on the
    instance for subsequent calls.
    """

    def __init__(self, cfg: AppConfig, tracer: Tracer):
        """
        Args:
            cfg: Application config; ``cfg.hf_model_repo`` names the Hub repo
                holding the model artifacts.
            tracer: Receives an ``"explain"`` event after each run that
                produces plots.
        """
        self.cfg = cfg
        self.tracer = tracer
        self._model = None  # set by _ensure_model() on first use
        self._feature_order = None  # "feature_order" from metadata, if available

    def _ensure_model(self):
        """Download and load the model (and optional feature order) once."""
        if self._model is not None:
            return
        token = os.getenv("HF_TOKEN")
        repo = self.cfg.hf_model_repo

        model_path = hf_hub_download(repo_id=repo, filename="model.pkl", token=token)
        # SECURITY NOTE: joblib.load unpickles arbitrary objects; only point
        # hf_model_repo at a repository you trust.
        self._model = joblib.load(model_path)

        # Feature metadata is optional: any failure (missing file, bad JSON,
        # unexpected structure) falls back to using the input columns as-is.
        try:
            meta_path = hf_hub_download(repo_id=repo, filename="feature_metadata.json", token=token)
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f) or {}
            self._feature_order = meta.get("feature_order")
        except Exception:
            self._feature_order = None

    @staticmethod
    def _to_data_uri(fig) -> str:
        """Render *fig* as a base64-encoded PNG data URI and close the figure."""
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
        finally:
            # Always release the figure so failed renders don't pile up in
            # pyplot's global figure registry.
            plt.close(fig)
        return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    def _render(self, plot_fn, sv) -> str:
        """Draw ``plot_fn(sv, show=False)`` on a fresh figure; return a PNG data URI."""
        fig = plt.figure()  # new figure becomes current; the shap plot is drawn into it
        try:
            plot_fn(sv, show=False)
        except Exception:
            # Don't leave a half-drawn figure registered with pyplot.
            plt.close(fig)
            raise
        return self._to_data_uri(fig)

    def run(self, df: Optional[pd.DataFrame], max_rows: int = 500) -> Dict[str, str]:
        """
        Compute SHAP values for (a sample of) *df* and render global plots.

        Args:
            df: Rows to explain. ``None`` or a frame with no rows yields ``{}``.
            max_rows: Upper bound on rows passed to SHAP; larger inputs are
                down-sampled with a fixed seed (``random_state=42``).

        Returns:
            ``{"global_bar": <png data URI>, "beeswarm": <png data URI>}``,
            or ``{}`` when there is nothing to explain.

        Raises:
            ValueError: if ``max_rows`` < 1, or if feature metadata is present
                but none of its columns exist in *df*.
        """
        if max_rows < 1:
            raise ValueError(f"max_rows must be >= 1, got {max_rows}")
        # Check for empty input first so a no-op call doesn't trigger a download.
        if df is None or len(df) == 0:
            return {}
        self._ensure_model()

        if self._feature_order:
            # Select and order columns per the metadata; names absent from df
            # are skipped.
            cols = [c for c in self._feature_order if c in df.columns]
            if not cols:
                raise ValueError(
                    "None of the model's feature_order columns are present in the input DataFrame"
                )
            X = df[cols].copy()
        else:
            X = df.copy()

        # Down-sample to bound SHAP's CPU cost; a fixed seed keeps plots stable.
        if len(X) > max_rows:
            X = X.sample(max_rows, random_state=42)
        n = len(X)

        # The same sample is used as background data and as the rows explained.
        explainer = shap.Explainer(self._model, X)
        sv = explainer(X)

        # Multi-output explanations carry values shaped (rows, features, outputs);
        # bar and beeswarm only accept (rows, features). Keep the last output —
        # NOTE(review): for a binary classifier this is presumably the positive
        # class; confirm against the model's label order.
        if getattr(sv.values, "ndim", 2) == 3:
            sv = sv[:, :, -1]

        bar_uri = self._render(shap.plots.bar, sv)
        bee_uri = self._render(shap.plots.beeswarm, sv)

        # Tracing is best-effort: a tracer failure must not discard the plots.
        try:
            self.tracer.trace_event("explain", {"rows": int(n)})
        except Exception:
            pass

        return {"global_bar": bar_uri, "beeswarm": bee_uri}