File size: 16,618 Bytes
68c51bb af53f4b da25b2a af53f4b da25b2a af53f4b 68c51bb af53f4b da25b2a af53f4b da25b2a af53f4b da25b2a af53f4b da25b2a c02152a af53f4b 68c51bb af53f4b da25b2a af53f4b 68c51bb af53f4b da25b2a af53f4b da25b2a af53f4b da25b2a 68c51bb da25b2a 68c51bb da25b2a af53f4b da25b2a af53f4b da25b2a af53f4b da25b2a af53f4b da25b2a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 |
# space/app.py
import os
import json
import logging
import gradio as gr
import pandas as pd
from typing import Optional, Tuple
from tools.sql_tool import SQLTool
from tools.predict_tool import PredictTool
from tools.explain_tool import ExplainTool
from tools.report_tool import ReportTool
from tools.ts_preprocess import build_timeseries
from tools.ts_forecast_tool import TimeseriesForecastTool
from utils.tracing import Tracer
from utils.config import AppConfig
# Logging: timestamped, logger-qualified records at INFO level and above.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Limits used further down in this module.
MAX_RESPONSE_LENGTH = 10_000  # longest user message accepted by validate_message()
MAX_FORECAST_HORIZON = 365  # upper bound applied to the forecast horizon
DEFAULT_FORECAST_HORIZON = 96  # horizon requested from the forecast tool
# Optional planner LLM. `llm` stays None unless the ORCHESTRATOR_MODEL
# environment variable names a model that loads without error; plan_actions()
# checks `llm is not None` before using it.
LLM_ID = os.getenv("ORCHESTRATOR_MODEL")
llm = None

if LLM_ID:
    try:
        # Imported here so transformers is only needed when a model is configured.
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

        logger.info(f"Loading orchestrator model: {LLM_ID}")
        _tok = AutoTokenizer.from_pretrained(LLM_ID)
        _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
        llm = pipeline(
            "text-generation",
            model=_mdl,
            tokenizer=_tok,
            max_new_tokens=512,
        )
        logger.info("Orchestrator model loaded successfully")
    except Exception as load_err:
        # Any import/load failure downgrades to the rule-based planner.
        logger.warning(f"Failed to load orchestrator model: {load_err}. Using fallback planner.")
        llm = None
# Build the shared configuration, tracer and tool instances at import time.
# A failure is logged and re-raised, so the module does not finish importing
# with tools missing.
try:
    cfg = AppConfig.from_env()
    tracer = Tracer.from_env()
    # Every tool takes the same (cfg, tracer) pair; construct them in the
    # original order: SQL, predict, explain, report, forecast.
    sql_tool, predict_tool, explain_tool, report_tool, ts_tool = (
        tool_cls(cfg, tracer)
        for tool_cls in (
            SQLTool,
            PredictTool,
            ExplainTool,
            ReportTool,
            TimeseriesForecastTool,
        )
    )
    logger.info("All tools initialized successfully")
except Exception as init_err:
    logger.error(f"Failed to initialize application: {init_err}")
    raise
# Instruction text placed in front of the user message when the LLM planner
# is used (see plan_actions). Sentences are joined with single spaces.
SYSTEM_PROMPT = " ".join(
    (
        "You are an analytical assistant for tabular data.",
        "Decide which tools to call in order:",
        "1) SQL (retrieve) 2) Predict (score) 3) Explain (SHAP) 4) Report (document) 5) Forecast (Granite TTM).",
        "Always disclose the steps taken.",
    )
)
def validate_message(message: str) -> Tuple[bool, str]:
    """Check a user message before it is planned and executed.

    A message is rejected when it is empty or whitespace-only, when it is
    longer than MAX_RESPONSE_LENGTH characters, or when its lower-cased text
    matches one of a small set of SQL-injection-looking patterns.

    Returns:
        ``(True, "")`` if the message is accepted, otherwise
        ``(False, reason)`` with a user-facing explanation.
    """
    import re

    if not (message and message.strip()):
        return False, "Please enter a valid question."

    if len(message) > MAX_RESPONSE_LENGTH:
        return False, f"Message too long. Please limit to {MAX_RESPONSE_LENGTH} characters."

    # Basic pattern screen only: it matches a handful of known shapes and is
    # applied to the lower-cased message.
    suspicious_patterns = (
        r';\s*drop\s+table',
        r';\s*delete\s+from',
        r';\s*truncate',
        r'union\s+select.*from',
        r'exec\s*\(',
        r'execute\s*\(',
    )
    lowered = message.lower()
    matched = next(
        (pat for pat in suspicious_patterns if re.search(pat, lowered)),
        None,
    )
    if matched is not None:
        logger.warning(f"Suspicious SQL pattern detected: {matched}")
        return False, "Query contains potentially unsafe patterns. Please rephrase."

    return True, ""
def plan_actions(message: str) -> dict:
    """
    Determine which tools to execute based on the user message.

    Uses the LLM planner when one is loaded; if the LLM is absent, fails, or
    yields no usable steps, falls back to keyword heuristics.

    Args:
        message: The (already validated) user query.

    Returns:
        A dict that always contains:
          - "steps": non-empty list drawn from
            'sql', 'predict', 'explain', 'report', 'forecast'
          - "rationale": short text describing why the plan was chosen
        An LLM-produced dict may carry additional keys.
    """
    # Canonical order, matching the numbering used in SYSTEM_PROMPT.
    step_order = ("sql", "predict", "explain", "report", "forecast")

    if llm is not None:
        prompt = (
            f"{SYSTEM_PROMPT}\nUser: {message}\n"
            "Return JSON with fields: steps (array subset of ['sql','predict','explain','report','forecast']), rationale."
        )
        try:
            out = llm(prompt)[0]["generated_text"]
            # Prefer a JSON object on the last line; otherwise try from the
            # last opening brace to the end of the output.
            last = out.split("\n")[-1].strip()
            obj = json.loads(last) if last.startswith("{") else json.loads(out[out.rfind("{"):])
            if isinstance(obj, dict) and isinstance(obj.get("steps"), list):
                # Keep only known step names; drop non-strings and duplicates
                # while preserving the order the LLM gave.
                kept = [s for s in obj["steps"] if isinstance(s, str) and s in step_order]
                obj["steps"] = list(dict.fromkeys(kept))
                if obj["steps"]:
                    # Callers read plan["rationale"]; make sure it exists even
                    # when the model omitted it.
                    obj.setdefault("rationale", "LLM-generated plan")
                    logger.info(f"LLM plan: {obj['steps']}")
                    return obj
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse LLM output as JSON: {e}")
        except Exception as e:
            logger.warning(f"LLM planning failed: {e}")

    # Fallback heuristic planning. Matching is plain substring search on the
    # lower-cased message, so a keyword also matches inside longer words.
    m = message.lower()

    def mentions(*keywords: str) -> bool:
        return any(k in m for k in keywords)

    wanted = set()

    # SQL keywords
    if mentions("show", "average", "count", "trend", "top", "sql", "query",
                "kpi", "data", "retrieve", "fetch", "list", "view"):
        wanted.add("sql")

    # Prediction keywords (prediction needs data first)
    if mentions("predict", "score", "risk", "propensity", "probability",
                "classification", "regression"):
        wanted.update(("sql", "predict"))

    # Explanation keywords (explanation needs data and predictions)
    if mentions("why", "explain", "shap", "feature", "attribution",
                "importance", "interpret"):
        wanted.update(("sql", "predict", "explain"))

    # Report keywords
    if mentions("report", "download", "pdf", "summary", "document", "export"):
        wanted.add("report")

    # Forecast keywords (forecast needs data first)
    if mentions("forecast", "next", "horizon", "granite", "predict future",
                "time series", "timeseries"):
        wanted.update(("sql", "forecast"))

    # Default to SQL if no steps identified
    if not wanted:
        wanted.add("sql")

    # Emit in canonical order so prerequisites are always listed before the
    # steps that depend on them.
    steps = [s for s in step_order if s in wanted]

    rationale = f"Rule-based plan based on keywords: {', '.join(steps)}"
    logger.info(f"Heuristic plan: {steps}")
    return {"steps": steps, "rationale": rationale}
def run_agent(
    message: str,
    hitl_decision: str = "Approve",
    reviewer_note: str = ""
) -> Tuple[str, pd.DataFrame]:
    """
    Main agent execution function.

    Validates the message, obtains a plan from plan_actions(), then runs the
    planned steps in a fixed order: SQL, predict, time-series build, forecast,
    explain, report. A failing step is recorded in the warnings list and the
    remaining steps still run.

    Args:
        message: User query
        hitl_decision: Human-in-the-loop decision. It is written to the trace
            only; it does not gate execution in this function.
        reviewer_note: Optional review notes (first 500 characters are traced)

    Returns:
        Tuple of (response_text, preview_dataframe). The preview holds at most
        100 rows and prefers the forecast, then predictions, then the SQL
        result; a one-row placeholder is used when none has data. On a
        validation, planning or unexpected failure, the text carries the error
        and the dataframe is empty.
    """
    try:
        # Validate input
        is_valid, error_msg = validate_message(message)
        if not is_valid:
            logger.warning(f"Invalid message: {error_msg}")
            return f"β **Error:** {error_msg}", pd.DataFrame()

        tracer.trace_event("user_message", {"message": message[:500]})  # Limit traced message length

        # Plan actions
        try:
            plan = plan_actions(message)
            tracer.trace_event("plan", plan)
        except Exception as e:
            logger.error(f"Planning failed: {e}")
            return f"β **Planning Error:** Unable to create execution plan. {str(e)}", pd.DataFrame()

        # Initialize result containers
        sql_df = None
        predict_df = None
        explain_imgs = {}
        artifacts = {}
        ts_forecast_df = None
        errors = []

        # Execute SQL step
        if "sql" in plan["steps"]:
            try:
                sql_result = sql_tool.run(message)
                if isinstance(sql_result, pd.DataFrame):
                    sql_df = sql_result
                    artifacts["sql_rows"] = len(sql_df)
                    logger.info(f"SQL returned {len(sql_df)} rows")
                else:
                    # Keep sql_df as None: later steps call `.empty` on it, so
                    # a non-DataFrame result must not leak past this point.
                    errors.append("SQL query returned no data")
            except Exception as e:
                error_msg = f"SQL execution failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # Execute prediction step
        if "predict" in plan["steps"]:
            try:
                if sql_df is not None and not sql_df.empty:
                    predict_df = predict_tool.run(sql_df)
                    if isinstance(predict_df, pd.DataFrame):
                        artifacts["predict_rows"] = len(predict_df)
                        logger.info(f"Predictions generated for {len(predict_df)} rows")
                else:
                    errors.append("Prediction skipped: no data available")
            except Exception as e:
                error_msg = f"Prediction failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # Build time series if possible
        ts_df = None
        if sql_df is not None and not sql_df.empty:
            try:
                built = build_timeseries(sql_df)
                if isinstance(built, pd.DataFrame):
                    ts_df = built
                    logger.info(f"Time series built with {len(ts_df)} records")
                else:
                    # The forecast step reads `.empty` and groups by column,
                    # so only a DataFrame is usable there.
                    logger.info("Time series preprocessing skipped: result is not a DataFrame")
            except Exception as e:
                logger.info(f"Time series preprocessing skipped: {e}")
                # Not always an error - data might not be suitable for TS

        # Execute forecast step
        if "forecast" in plan["steps"]:
            if ts_df is not None and not ts_df.empty:
                try:
                    # Aggregate portfolio value by timestamp
                    agg = ts_df.groupby("timestamp", as_index=True)["portfolio_value"].sum().sort_index()

                    if len(agg) < 2:
                        errors.append("Insufficient time series data for forecasting (need at least 2 points)")
                    else:
                        # Validate horizon
                        horizon = min(DEFAULT_FORECAST_HORIZON, MAX_FORECAST_HORIZON)
                        ts_forecast_df = ts_tool.zeroshot_forecast(agg, horizon=horizon)

                        if isinstance(ts_forecast_df, pd.DataFrame):
                            # The tool signals failure through an "error" column.
                            if "error" in ts_forecast_df.columns:
                                errors.append(f"Forecast error: {ts_forecast_df['error'].iloc[0]}")
                                ts_forecast_df = None
                            else:
                                artifacts["forecast_horizon"] = len(ts_forecast_df)
                                logger.info(f"Forecast generated for {len(ts_forecast_df)} periods")
                except Exception as e:
                    error_msg = f"Forecasting failed: {str(e)}"
                    logger.error(error_msg)
                    errors.append(error_msg)
            else:
                errors.append("Forecast skipped: no suitable time series data")

        # Execute explanation step
        if "explain" in plan["steps"]:
            try:
                # Explain the predictions when available, else the raw SQL data.
                explain_data = predict_df if predict_df is not None else sql_df
                if explain_data is not None and not explain_data.empty:
                    explain_imgs = explain_tool.run(explain_data)
                    artifacts["explain_charts"] = len(explain_imgs)
                    logger.info(f"Generated {len(explain_imgs)} explanation charts")
                else:
                    errors.append("Explanation skipped: no data available")
            except Exception as e:
                error_msg = f"Explanation failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # Execute report generation
        report_link = None
        if "report" in plan["steps"]:
            try:
                # When there are no predictions, the forecast preview is shown
                # in the report's prediction slot instead.
                forecast_preview = ts_forecast_df.head(50) if isinstance(ts_forecast_df, pd.DataFrame) else None

                report_link = report_tool.render_and_save(
                    user_query=message,
                    sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
                    predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else forecast_preview,
                    explain_images=explain_imgs,
                    plan=plan,
                )
                logger.info(f"Report generated: {report_link}")
            except Exception as e:
                error_msg = f"Report generation failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # Log human-in-the-loop decision
        tracer.trace_event("hitl", {
            "message": message[:500],
            "decision": hitl_decision,
            "reviewer_note": reviewer_note[:500] if reviewer_note else "",
            "artifacts": artifacts,
            "plan": plan,
            "errors": errors,
        })

        # Compose response. A plan without a rationale must not abort the
        # request after all the work above has been done.
        rationale = plan.get("rationale", "")
        response = f"**Plan:** {', '.join(plan['steps'])}\n\n**Rationale:** {rationale}\n\n"

        # Add artifacts info
        if artifacts:
            response += "**Results:**\n"
            if "sql_rows" in artifacts:
                response += f"- SQL query returned {artifacts['sql_rows']} rows\n"
            if "predict_rows" in artifacts:
                response += f"- Generated predictions for {artifacts['predict_rows']} rows\n"
            if "forecast_horizon" in artifacts:
                response += f"- Forecast generated for {artifacts['forecast_horizon']} periods\n"
            if "explain_charts" in artifacts:
                response += f"- Created {artifacts['explain_charts']} explanation charts\n"
            response += "\n"

        # Add report link
        if report_link:
            response += f"π **Report:** {report_link}\n\n"

        # Add trace URL
        if tracer.trace_url:
            response += f"π **Trace:** {tracer.trace_url}\n\n"

        # Add errors if any
        if errors:
            response += "**β οΈ Warnings/Errors:**\n"
            for err in errors:
                response += f"- {err}\n"

        # Determine preview dataframe
        if isinstance(ts_forecast_df, pd.DataFrame) and not ts_forecast_df.empty:
            preview_df = ts_forecast_df.head(100)
        elif isinstance(predict_df, pd.DataFrame) and not predict_df.empty:
            preview_df = predict_df.head(100)
        elif isinstance(sql_df, pd.DataFrame) and not sql_df.empty:
            preview_df = sql_df.head(100)
        else:
            preview_df = pd.DataFrame({"message": ["No data to display"]})

        return response, preview_df

    except Exception as e:
        # Last-resort boundary: log with traceback, trace, and report to the UI.
        error_msg = f"Unexpected error in agent execution: {str(e)}"
        logger.exception(error_msg)
        tracer.trace_event("error", {"message": error_msg})
        return f"β **Critical Error:** {error_msg}", pd.DataFrame()
# Gradio Interface
#
# Layout, top to bottom: intro text, question box, review controls, response
# markdown, data preview, action buttons, tips. Markdown bodies are kept in
# private constants so the layout code below stays short.
_INTRO_MD = """
# π€ Tabular Agentic XAI (Enterprise Edition)
An intelligent assistant for analyzing tabular data with ML predictions, explanations, and time-series forecasting.
**Capabilities:**
- π SQL queries and data retrieval
- π― ML predictions with confidence scores
- π SHAP-based model explanations
- π Time-series forecasting with Granite TTM
- π Automated report generation
"""

_TIPS_MD = """
---
**Tips:**
- Be specific in your queries for better results
- Use natural language - the system will interpret your intent
- Review the execution plan before approving
- Check the trace link for detailed execution logs
"""


def _reset_form():
    """Return blank values for (msg, hitl, note, out_md, out_df), in that order."""
    return "", "Approve", "", "", pd.DataFrame()


with gr.Blocks(title="Tabular Agentic XAI") as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        msg = gr.Textbox(
            label="Ask your question",
            placeholder="e.g., Show me the top 10 customers by revenue, predict churn risk, forecast next quarter...",
            lines=3
        )

    with gr.Row():
        hitl = gr.Radio(
            choices=["Approve", "Needs Changes"],
            value="Approve",
            label="Human Review",
            info="Review the planned actions before execution"
        )
        note = gr.Textbox(
            label="Reviewer note (optional)",
            placeholder="Add any review comments...",
            lines=2
        )

    out_md = gr.Markdown(label="Response")
    out_df = gr.Dataframe(
        interactive=False,
        label="Data Preview (max 100 rows)",
        wrap=True
    )

    with gr.Row():
        ask = gr.Button("π Run Analysis", variant="primary")
        clear = gr.Button("π Clear")

    # Run: the three inputs map onto run_agent's parameters in order, and its
    # (text, dataframe) result fills the response and preview components.
    ask.click(
        fn=run_agent,
        inputs=[msg, hitl, note],
        outputs=[out_md, out_df]
    )
    # Clear: reset every input and output component to its initial value.
    clear.click(
        fn=_reset_form,
        outputs=[msg, hitl, note, out_md, out_df]
    )

    gr.Markdown(_TIPS_MD)
if __name__ == "__main__":
    # Script entry point: serve the UI on every interface, port 7860, with
    # Gradio's show_error option switched on.
    logger.info("Starting Gradio application...")
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)