# ALM_LLM / tools/predict_tool.py
# Last change by AshenH: "Update tools/predict_tool.py" (commit 32817e1, verified; 2.47 kB)
# space/tools/predict_tool.py
import json
import os
from typing import List, Optional

import joblib
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download

from utils.config import AppConfig
from utils.tracing import Tracer
class PredictTool:
    """
    Loads a sklearn-compatible tabular model from a HF repo and runs predictions.

    ``model.pkl`` and the optional ``feature_metadata.json`` are fetched lazily
    from ``cfg.hf_model_repo`` the first time :meth:`run` is called.
    Recognised metadata keys:

    * ``prediction_column`` - name of the output column (default ``"prediction"``).
    * ``feature_order`` - columns, in order, passed to the model; when absent
      or empty, every input column is passed through unchanged.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        self.cfg = cfg
        self.tracer = tracer
        self._model = None  # populated lazily by _ensure_loaded()
        self._feature_meta = {}
        self._pred_col = "prediction"
        self._feature_order: Optional[List[str]] = None

    def _ensure_loaded(self) -> None:
        """
        Download and deserialize the model and optional metadata on first use.

        Subsequent calls are no-ops once the model has been loaded.
        """
        if self._model is not None:
            return
        token = os.getenv("HF_TOKEN")  # None when the variable is unset
        repo = self.cfg.hf_model_repo
        model_path = hf_hub_download(repo_id=repo, filename="model.pkl", token=token)
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code; only point hf_model_repo at a repository you trust.
        model = joblib.load(model_path)
        try:
            meta_path = hf_hub_download(
                repo_id=repo, filename="feature_metadata.json", token=token
            )
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f)
        except Exception:
            # Metadata is optional: any download/read/parse failure means "use defaults".
            meta = {}
        self._apply_metadata(meta)
        # Publish the model last so that a failure above leaves the tool
        # unloaded (and retryable) rather than half-initialised.
        self._model = model

    def _apply_metadata(self, meta) -> None:
        """Normalise a parsed ``feature_metadata.json`` payload and store its settings."""
        if not isinstance(meta, dict):
            # null / list / scalar JSON carries no usable settings; treat it as
            # absent instead of crashing on .get() below.
            meta = {}
        self._feature_meta = meta
        # `or` also covers an explicit null / empty-string column name.
        self._pred_col = meta.get("prediction_column") or "prediction"
        self._feature_order = meta.get("feature_order")

    def _select_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Return a copy of ``df`` restricted to, and ordered by, ``feature_order``.

        Raises:
            ValueError: if any column named in ``feature_order`` is missing.
        """
        if self._feature_order:
            missing = [c for c in self._feature_order if c not in df.columns]
            if missing:
                raise ValueError(f"Missing required features for model: {missing}")
            return df[self._feature_order].copy()
        return df.copy()

    @staticmethod
    def _sigmoid(raw) -> np.ndarray:
        """Apply the logistic function element-wise without overflowing ``exp``."""
        raw = np.asarray(raw, dtype=float)
        # exp(-|raw|) is always in [0, 1], so neither branch can overflow.
        z = np.exp(-np.abs(raw))
        return np.where(raw >= 0, 1.0 / (1.0 + z), z / (1.0 + z))

    def _score(self, X: pd.DataFrame):
        """Return the loaded model's scores for ``X``, preferring probabilities."""
        model = self._model
        if hasattr(model, "predict_proba"):
            # Keep only the last probability column. NOTE(review): for sklearn
            # classifiers columns follow model.classes_, so for multiclass models
            # this is just the last class - confirm that is intended.
            return model.predict_proba(X)[:, -1]
        if hasattr(model, "decision_function"):
            # NOTE(review): assumes a 1-D margin per row; a multiclass 2-D
            # decision_function output cannot be stored in a single column.
            return self._sigmoid(model.decision_function(X))
        return model.predict(X)

    def run(self, df: Optional[pd.DataFrame]) -> pd.DataFrame:
        """
        Score ``df`` and return a copy with the prediction column appended.

        Scores come from, in order of preference: the last ``predict_proba``
        column, a logistic transform of ``decision_function``, or raw
        ``predict`` output.

        Args:
            df: Input rows; may be ``None``.

        Returns:
            A copy of ``df`` (original columns and index preserved) plus the
            prediction column, or an empty DataFrame when ``df`` is ``None``
            or has no rows. ``df`` itself is not modified.

        Raises:
            ValueError: if columns listed in ``feature_order`` are missing.
        """
        # The model is loaded even for empty input, so load errors surface here.
        self._ensure_loaded()
        if df is None or len(df) == 0:
            return pd.DataFrame()
        X = self._select_features(df)
        preds = self._score(X)
        out = df.copy()
        if isinstance(preds, pd.Series):
            # Predictions are row-aligned with X; relabel them positionally so a
            # default RangeIndex does not misalign against a custom df.index
            # (label alignment would silently produce NaN).
            preds = preds.set_axis(out.index)
        out[self._pred_col] = preds
        try:
            self.tracer.trace_event("predict", {"rows": len(out)})
        except Exception:
            # Tracing is best-effort: an observability failure must not discard
            # an otherwise successful prediction.
            pass
        return out