File size: 11,364 Bytes
aed2def 9b2b28b aed2def 9b2b28b e4818d5 aed2def 9b2b28b e4818d5 aed2def 9b2b28b aed2def 9b2b28b e4818d5 aed2def 9b2b28b aed2def e4818d5 aed2def e4818d5 9b2b28b aed2def e4818d5 9b2b28b e4818d5 aed2def e4818d5 aed2def e4818d5 aed2def e4818d5 aed2def e4818d5 aed2def e4818d5 aed2def e4818d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 |
# space/tools/explain_tool.py
import os
import io
import json
import base64
import logging
from typing import Dict, Optional
import shap
import pandas as pd
import matplotlib
matplotlib.use('Agg') # Non-interactive backend
import matplotlib.pyplot as plt
import joblib
from huggingface_hub import hf_hub_download
from utils.config import AppConfig
from utils.tracing import Tracer
logger = logging.getLogger(__name__)

# Row-count bounds applied by ExplainTool._sample_data before SHAP runs.
MIN_SAMPLE_SIZE = 10
DEFAULT_SAMPLE_SIZE = 500
MAX_SAMPLE_SIZE = 1000

# Encoded PNGs larger than this (in MiB) trigger a warning, not an error.
MAX_IMAGE_SIZE_MB = 5
class ExplainToolError(Exception):
    """Raised when ExplainTool cannot load its model, prepare features, or build explanations."""
class ExplainTool:
    """
    Generate SHAP-based global explanations for a model hosted on HuggingFace.

    The model (``model.pkl``) and optional ``feature_metadata.json`` are
    downloaded lazily on first use. Input rows are down-sampled to keep SHAP
    computation CPU-friendly, and the resulting plots are returned as base64
    PNG data URIs.
    """

    def __init__(self, cfg: AppConfig, tracer: Tracer):
        """
        Args:
            cfg: Application config; ``cfg.hf_model_repo`` names the HF repo
                containing ``model.pkl`` and ``feature_metadata.json``.
            tracer: Event sink used via ``trace_event``; a falsy value
                (e.g. None) disables tracing.
        """
        self.cfg = cfg
        self.tracer = tracer
        self._model = None  # populated on first use by _ensure_model()
        self._feature_order = None  # column order from feature metadata, if any
        logger.info("ExplainTool initialized (lazy loading)")

    def _ensure_model(self) -> None:
        """
        Download and load the model and feature metadata on first use.

        The model is required. Feature metadata is best effort: if it cannot
        be loaded, ``_feature_order`` stays ``None`` and all numeric columns
        are used later.

        Raises:
            ExplainToolError: if no repo is configured or the model cannot be
                downloaded/loaded.
        """
        if self._model is not None:
            return
        try:
            token = os.getenv("HF_TOKEN")
            repo = self.cfg.hf_model_repo
            if not repo:
                raise ExplainToolError("HF_MODEL_REPO not configured")
            logger.info(f"Loading model for explanations from: {repo}")

            # Download and load model.
            # SECURITY: joblib.load unpickles the file, which can execute
            # arbitrary code -- only point HF_MODEL_REPO at a trusted repo.
            try:
                model_path = hf_hub_download(
                    repo_id=repo,
                    filename="model.pkl",
                    token=token,
                )
                self._model = joblib.load(model_path)
                logger.info(f"Model loaded: {type(self._model).__name__}")
            except Exception as e:
                raise ExplainToolError(f"Failed to load model: {e}") from e

            # Load feature metadata (optional; failures only warn).
            try:
                meta_path = hf_hub_download(
                    repo_id=repo,
                    filename="feature_metadata.json",
                    token=token,
                )
                with open(meta_path, "r", encoding="utf-8") as f:
                    meta = json.load(f) or {}
                self._feature_order = meta.get("feature_order")
                logger.info(f"Loaded feature order: {len(self._feature_order or [])} features")
            except Exception as e:
                logger.warning(f"Could not load feature metadata: {e}")
                self._feature_order = None
        except ExplainToolError:
            raise
        except Exception as e:
            raise ExplainToolError(f"Model initialization failed: {e}") from e

    def _validate_data(self, df: pd.DataFrame) -> tuple[bool, str]:
        """
        Validate the input dataframe.

        Returns:
            ``(is_valid, error_message)``; the message is ``""`` when valid.
        """
        if df is None:
            return False, "Input dataframe is empty"
        # Check columns before ``df.empty``: a frame with zero columns is also
        # ``empty``, so checking emptiness first made this branch unreachable.
        if len(df.columns) == 0:
            return False, "Dataframe has no columns"
        if df.empty:
            return False, "Input dataframe is empty"
        return True, ""

    def _prepare_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Build the numeric feature matrix used for SHAP analysis.

        When a feature order is known, the columns present in ``df`` are
        selected in that order; missing ones are logged. Otherwise all columns
        are used. Non-numeric columns are then dropped.

        Raises:
            ExplainToolError: if none of the ordered features are present or
                no numeric column remains.
        """
        if self._feature_order:
            available_features = [col for col in self._feature_order if col in df.columns]
            missing_features = [col for col in self._feature_order if col not in df.columns]
            if missing_features:
                logger.warning(
                    f"Missing {len(missing_features)} features for explanation: "
                    f"{missing_features[:5]}"
                )
            if not available_features:
                raise ExplainToolError(
                    f"No required features found in dataframe. "
                    f"Required: {self._feature_order}, "
                    f"Available: {list(df.columns)}"
                )
            X = df[available_features].copy()
            logger.info(f"Using {len(available_features)} features for explanation")
        else:
            X = df.copy()
            logger.warning("No feature order specified - using all columns")

        # Remove non-numeric columns (list keeps the log output in column order).
        numeric_cols = X.select_dtypes(include=['number']).columns
        if len(numeric_cols) < len(X.columns):
            dropped = [col for col in X.columns if col not in numeric_cols]
            logger.warning(f"Dropping {len(dropped)} non-numeric columns: {dropped[:5]}")
            X = X[numeric_cols]

        if X.empty or len(X.columns) == 0:
            raise ExplainToolError("No numeric features available for explanation")
        return X

    def _sample_data(self, X: pd.DataFrame, sample_size: Optional[int] = None) -> pd.DataFrame:
        """
        Down-sample rows to keep SHAP computation manageable.

        Args:
            X: Feature matrix.
            sample_size: Requested row count; ``None`` means
                ``DEFAULT_SAMPLE_SIZE``. It is clamped to
                ``[MIN_SAMPLE_SIZE, MAX_SAMPLE_SIZE]``.

        Returns:
            ``X`` unchanged if it already has no more rows than the clamped
            target (or no more than ``MIN_SAMPLE_SIZE``); otherwise a
            reproducible random sample of the target size.
        """
        if sample_size is None:
            # Resolved at call time rather than as a def-time default.
            sample_size = DEFAULT_SAMPLE_SIZE

        n = len(X)
        if n <= MIN_SAMPLE_SIZE:
            logger.info(f"Using all {n} rows (below minimum sample size)")
            return X

        target_size = max(min(sample_size, MAX_SAMPLE_SIZE), MIN_SAMPLE_SIZE)
        if n <= target_size:
            logger.info(f"Using all {n} rows (below target sample size)")
            return X

        # Uniform random sample without replacement; fixed seed for reproducibility.
        try:
            sample = X.sample(n=target_size, random_state=42)
            logger.info(f"Sampled {target_size} rows from {n} total")
            return sample
        except Exception as e:
            logger.warning(f"Sampling failed: {e}, using head()")
            return X.head(target_size)

    @staticmethod
    def _to_data_uri(fig) -> str:
        """
        Render a matplotlib figure to a ``data:image/png;base64,...`` URI.

        The figure is always closed, even if rendering fails. Images larger
        than ``MAX_IMAGE_SIZE_MB`` only produce a warning.

        Raises:
            ExplainToolError: if the figure cannot be rendered or encoded.
        """
        try:
            buf = io.BytesIO()
            try:
                fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
            finally:
                # Close even when savefig fails so pyplot does not retain the figure.
                plt.close(fig)

            png_bytes = buf.getvalue()
            size_mb = len(png_bytes) / (1024 * 1024)
            if size_mb > MAX_IMAGE_SIZE_MB:
                logger.warning(f"Generated image is large: {size_mb:.2f} MB")

            data_uri = "data:image/png;base64," + base64.b64encode(png_bytes).decode("ascii")
            logger.debug(f"Generated data URI of size: {len(data_uri)} chars")
            return data_uri
        except Exception as e:
            logger.error(f"Failed to convert figure to data URI: {e}")
            raise ExplainToolError(f"Image conversion failed: {e}") from e

    def _generate_shap_values(self, X: pd.DataFrame) -> shap.Explanation:
        """
        Compute SHAP values for ``X``.

        ``X`` serves both as the explainer's background data and as the rows
        being explained.

        Raises:
            ExplainToolError: if explainer construction or evaluation fails.
        """
        try:
            logger.info("Creating SHAP explainer...")
            explainer = shap.Explainer(self._model, X)
            logger.info("Computing SHAP values...")
            shap_values = explainer(X)
            logger.info(f"SHAP values computed: shape={shap_values.values.shape}")
            return shap_values
        except Exception as e:
            raise ExplainToolError(f"SHAP computation failed: {e}") from e

    def _create_bar_plot(self, shap_values: shap.Explanation) -> str:
        """Create the global feature-importance bar plot; returns ``""`` on failure."""
        fig = None
        try:
            logger.info("Creating bar plot...")
            fig = plt.figure(figsize=(10, 6))
            shap.plots.bar(shap_values, show=False, max_display=20)
            plt.title("Feature Importance (Global)", fontsize=14, pad=20)
            plt.xlabel("Mean |SHAP value|", fontsize=12)
            plt.tight_layout()
            uri = self._to_data_uri(fig)
            logger.info("Bar plot created successfully")
            return uri
        except Exception as e:
            # Degrade gracefully: a missing plot should not fail the whole run.
            logger.exception(f"Bar plot creation failed: {e}")
            return ""
        finally:
            # Previously leaked when shap.plots.bar raised; closing twice is a no-op.
            if fig is not None:
                plt.close(fig)

    def _create_beeswarm_plot(self, shap_values: shap.Explanation) -> str:
        """Create the beeswarm plot of per-feature effects; returns ``""`` on failure."""
        fig = None
        try:
            logger.info("Creating beeswarm plot...")
            fig = plt.figure(figsize=(10, 8))
            shap.plots.beeswarm(shap_values, show=False, max_display=20)
            plt.title("Feature Effects Distribution", fontsize=14, pad=20)
            plt.tight_layout()
            uri = self._to_data_uri(fig)
            logger.info("Beeswarm plot created successfully")
            return uri
        except Exception as e:
            logger.exception(f"Beeswarm plot creation failed: {e}")
            return ""
        finally:
            if fig is not None:
                plt.close(fig)

    def _trace(self, event: str, payload: Dict[str, object]) -> None:
        """Emit a tracer event if a tracer is set; tracer failures are logged, never raised."""
        if not self.tracer:
            return
        try:
            self.tracer.trace_event(event, payload)
        except Exception:
            # Telemetry must not mask the real result/error of run().
            logger.warning(f"Failed to record trace event '{event}'", exc_info=True)

    def run(self, df: Optional[pd.DataFrame]) -> Dict[str, str]:
        """
        Generate SHAP explanations for input data.

        Args:
            df: Input dataframe with features.

        Returns:
            Mapping of plot name (``"global_bar"``, ``"beeswarm"``) to base64
            PNG data URI. Plots that fail to render are omitted. Returns
            ``{}`` when the input is invalid (``None``, empty, or no columns).

        Raises:
            ExplainToolError: if model loading, feature preparation or SHAP
                computation fails. An ``explain_error`` trace event is emitted
                before raising.
        """
        try:
            is_valid, error_msg = self._validate_data(df)
            if not is_valid:
                logger.warning(f"Invalid input: {error_msg}")
                return {}

            self._ensure_model()

            X = self._prepare_features(df)
            logger.info(f"Prepared features: {X.shape}")

            sample = self._sample_data(X)
            shap_values = self._generate_shap_values(sample)

            result: Dict[str, str] = {}
            bar_uri = self._create_bar_plot(shap_values)
            if bar_uri:
                result["global_bar"] = bar_uri
            bee_uri = self._create_beeswarm_plot(shap_values)
            if bee_uri:
                result["beeswarm"] = bee_uri

            logger.info(f"Generated {len(result)} explanation visualizations")
            self._trace("explain", {
                "rows": len(sample),
                "features": len(X.columns),
                "visualizations": len(result),
            })
            return result
        except ExplainToolError as e:
            # Most failures arrive here already wrapped; trace them too.
            self._trace("explain_error", {"error": str(e)})
            raise
        except Exception as e:
            error_msg = f"Explanation generation failed: {str(e)}"
            logger.error(error_msg)
            self._trace("explain_error", {"error": error_msg})
            raise ExplainToolError(error_msg) from e