|
|
|
|
|
import json
import logging
import os
import re
from typing import Optional, Tuple

import gradio as gr
import pandas as pd

from tools.sql_tool import SQLTool
from tools.predict_tool import PredictTool
from tools.explain_tool import ExplainTool
from tools.report_tool import ReportTool
from tools.ts_preprocess import build_timeseries
from tools.ts_forecast_tool import TimeseriesForecastTool

from utils.tracing import Tracer
from utils.config import AppConfig
|
|
|
|
|
|
|
|
# Module-wide logging: timestamped records tagged with logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

logger = logging.getLogger(__name__)

# Guardrail constants.
# NOTE(review): despite its name, MAX_RESPONSE_LENGTH is used in
# validate_message() as the cap on *input* message length — confirm intent.
MAX_RESPONSE_LENGTH = 10000

# Hard upper bound on the number of forecast periods requested from the
# time-series tool.
MAX_FORECAST_HORIZON = 365

# Horizon used when the user does not specify one (see run_agent).
DEFAULT_FORECAST_HORIZON = 96
|
|
|
|
|
|
|
|
# Optional LLM-based planner. If ORCHESTRATOR_MODEL names a model id, load it
# as a text-generation pipeline; otherwise (or on any load failure) `llm`
# stays None and plan_actions falls back to keyword heuristics.
llm = None

LLM_ID = os.getenv("ORCHESTRATOR_MODEL")

if LLM_ID:
    try:
        # Imported lazily so the app still runs when transformers is not
        # installed and no orchestrator model is configured.
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

        logger.info(f"Loading orchestrator model: {LLM_ID}")

        _tok = AutoTokenizer.from_pretrained(LLM_ID)
        _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
        llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)

        logger.info("Orchestrator model loaded successfully")
    except Exception as e:
        # Deliberate best-effort: a missing or broken model downgrades
        # planning to heuristics rather than crashing at import time.
        logger.warning(f"Failed to load orchestrator model: {e}. Using fallback planner.")
        llm = None
|
|
|
|
|
|
|
|
# Application-wide singletons: configuration, tracer, and the five tools the
# agent orchestrates in run_agent. Initialization failures are fatal — the
# app cannot serve requests without them, so the exception is re-raised.
try:
    cfg = AppConfig.from_env()
    tracer = Tracer.from_env()

    sql_tool = SQLTool(cfg, tracer)          # data retrieval from a user query
    predict_tool = PredictTool(cfg, tracer)  # ML scoring over retrieved rows
    explain_tool = ExplainTool(cfg, tracer)  # explanation charts (SHAP per SYSTEM_PROMPT)
    report_tool = ReportTool(cfg, tracer)    # report rendering/saving
    ts_tool = TimeseriesForecastTool(cfg, tracer)  # zero-shot time-series forecasts

    logger.info("All tools initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize application: {e}")
    raise
|
|
|
|
|
# System prompt handed to the orchestrator LLM in plan_actions; it spells out
# the available tools and the expected calling order.
SYSTEM_PROMPT = (
    "You are an analytical assistant for tabular data. "
    "Decide which tools to call in order: "
    "1) SQL (retrieve) 2) Predict (score) 3) Explain (SHAP) 4) Report (document) 5) Forecast (Granite TTM). "
    "Always disclose the steps taken."
)
|
|
|
|
|
|
|
|
def validate_message(message: str) -> Tuple[bool, str]:
    """Validate a user input message before planning and execution.

    Checks, in order:
      1. Non-empty after stripping whitespace.
      2. Length does not exceed MAX_RESPONSE_LENGTH characters.
      3. No obviously destructive / injection-style SQL patterns.

    Args:
        message: Raw user query text.

    Returns:
        Tuple of (is_valid, error_message); error_message is "" when valid.
    """
    if not message or not message.strip():
        return False, "Please enter a valid question."

    if len(message) > MAX_RESPONSE_LENGTH:
        return False, f"Message too long. Please limit to {MAX_RESPONSE_LENGTH} characters."

    # Conservative deny-list of SQL fragments that should never appear in a
    # natural-language analytics question. Defence-in-depth only — the SQL
    # tool must still validate/parameterize whatever it executes.
    suspicious_patterns = [
        r';\s*drop\s+table',
        r';\s*delete\s+from',
        r';\s*truncate',
        r'union\s+select.*from',
        r'exec\s*\(',
        r'execute\s*\('
    ]

    msg_lower = message.lower()
    for pattern in suspicious_patterns:
        if re.search(pattern, msg_lower):
            logger.warning(f"Suspicious SQL pattern detected: {pattern}")
            return False, "Query contains potentially unsafe patterns. Please rephrase."

    return True, ""
|
|
|
|
|
|
|
|
def plan_actions(message: str) -> dict:
    """Determine which tools to execute for a user message.

    Tries the LLM planner first (when an orchestrator model was loaded) and
    falls back to keyword heuristics when the LLM is unavailable or returns
    unusable output.

    Args:
        message: User query text.

    Returns:
        dict with keys:
            steps: ordered subset of
                ['sql', 'predict', 'explain', 'forecast', 'report']
            rationale: human-readable explanation of the plan
    """
    if llm is not None:
        plan = _plan_with_llm(message)
        if plan is not None:
            return plan
    return _heuristic_plan(message)


def _plan_with_llm(message: str) -> Optional[dict]:
    """Ask the orchestrator LLM for a plan; return None when unusable."""
    prompt = (
        f"{SYSTEM_PROMPT}\nUser: {message}\n"
        "Return JSON with fields: steps (array subset of ['sql','predict','explain','report','forecast']), rationale."
    )
    try:
        out = llm(prompt)[0]["generated_text"]
        # The model may echo the prompt; take the last line if it looks like
        # JSON, otherwise grab the last '{'-delimited suffix of the output.
        last = out.split("\n")[-1].strip()
        obj = json.loads(last) if last.startswith("{") else json.loads(out[out.rfind("{"):])

        if isinstance(obj, dict) and "steps" in obj:
            # Keep only recognized tool names; reject empty plans so the
            # caller falls back to heuristics.
            valid_steps = {'sql', 'predict', 'explain', 'report', 'forecast'}
            obj["steps"] = [s for s in obj["steps"] if s in valid_steps]
            if obj["steps"]:
                logger.info(f"LLM plan: {obj['steps']}")
                return obj
    except json.JSONDecodeError as e:
        logger.warning(f"Failed to parse LLM output as JSON: {e}")
    except Exception as e:
        logger.warning(f"LLM planning failed: {e}")
    return None


def _heuristic_plan(message: str) -> dict:
    """Keyword-based fallback planner.

    Selects tools whose trigger keywords appear in the message, adds any
    upstream dependencies (explain needs predictions; predict/explain/
    forecast all need SQL data), and emits steps in execution order —
    fixing the original's possibility of 'predict' being inserted ahead
    of 'sql'.
    """
    m = message.lower()

    keyword_triggers = {
        "sql": ["show", "average", "count", "trend", "top", "sql", "query",
                "kpi", "data", "retrieve", "fetch", "list", "view"],
        "predict": ["predict", "score", "risk", "propensity", "probability",
                    "classification", "regression"],
        "explain": ["why", "explain", "shap", "feature", "attribution",
                    "importance", "interpret"],
        "report": ["report", "download", "pdf", "summary", "document", "export"],
        "forecast": ["forecast", "next", "horizon", "granite", "predict future",
                     "time series", "timeseries"],
    }
    wanted = {step for step, keywords in keyword_triggers.items()
              if any(k in m for k in keywords)}

    # Pull in upstream dependencies so each requested step has its inputs.
    if "explain" in wanted:
        wanted.add("predict")
    if wanted & {"predict", "explain", "forecast"}:
        wanted.add("sql")

    # Default when nothing matched: plain retrieval.
    if not wanted:
        wanted = {"sql"}

    # Emit in execution order (upstream producers before consumers).
    execution_order = ["sql", "predict", "explain", "forecast", "report"]
    steps = [s for s in execution_order if s in wanted]

    rationale = f"Rule-based plan based on keywords: {', '.join(steps)}"
    logger.info(f"Heuristic plan: {steps}")
    return {"steps": steps, "rationale": rationale}
|
|
|
|
|
|
|
|
def run_agent(
    message: str,
    hitl_decision: str = "Approve",
    reviewer_note: str = ""
) -> Tuple[str, pd.DataFrame]:
    """
    Main agent execution function: validate, plan, run tools, summarize.

    The pipeline runs in a fixed order regardless of the order steps appear
    in the plan: SQL retrieval -> prediction -> time-series preprocessing ->
    forecasting -> explanation -> report. Each stage is fault-isolated:
    failures are appended to `errors` and surfaced in the response instead
    of aborting the run.

    Args:
        message: User query.
        hitl_decision: Human-in-the-loop decision. NOTE(review): this is
            only recorded in the trace — it does not currently gate
            execution; confirm whether "Needs Changes" should block tools.
        reviewer_note: Optional review notes (traced, truncated to 500 chars).

    Returns:
        Tuple of (response_text, preview_dataframe). The preview holds at
        most 100 rows of the most downstream artifact available
        (forecast > predictions > SQL result), or a placeholder frame.
    """
    try:
        # --- Input validation --------------------------------------------
        is_valid, error_msg = validate_message(message)
        if not is_valid:
            logger.warning(f"Invalid message: {error_msg}")
            return f"β **Error:** {error_msg}", pd.DataFrame()

        # Trace only a bounded prefix of the message.
        tracer.trace_event("user_message", {"message": message[:500]})

        # --- Planning ----------------------------------------------------
        try:
            plan = plan_actions(message)
            tracer.trace_event("plan", plan)
        except Exception as e:
            logger.error(f"Planning failed: {e}")
            return f"β **Planning Error:** Unable to create execution plan. {str(e)}", pd.DataFrame()

        # Per-run state: tool outputs, summary counters, collected errors.
        sql_df = None
        predict_df = None
        explain_imgs = {}
        artifacts = {}
        ts_forecast_df = None
        errors = []

        # --- Step 1: SQL retrieval ----------------------------------------
        if "sql" in plan["steps"]:
            try:
                sql_df = sql_tool.run(message)
                if isinstance(sql_df, pd.DataFrame):
                    artifacts["sql_rows"] = len(sql_df)
                    logger.info(f"SQL returned {len(sql_df)} rows")
                else:
                    errors.append("SQL query returned no data")
            except Exception as e:
                error_msg = f"SQL execution failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # --- Step 2: prediction (requires non-empty SQL result) -----------
        if "predict" in plan["steps"]:
            try:
                if sql_df is not None and not sql_df.empty:
                    predict_df = predict_tool.run(sql_df)
                    if isinstance(predict_df, pd.DataFrame):
                        artifacts["predict_rows"] = len(predict_df)
                        logger.info(f"Predictions generated for {len(predict_df)} rows")
                else:
                    errors.append("Prediction skipped: no data available")
            except Exception as e:
                error_msg = f"Prediction failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # --- Time-series preprocessing (speculative: runs whenever SQL
        # produced data, even if no forecast was planned; failures are
        # expected for non-temporal data and logged at INFO only) ----------
        ts_df = None
        if sql_df is not None and not sql_df.empty:
            try:
                ts_df = build_timeseries(sql_df)
                logger.info(f"Time series built with {len(ts_df)} records")
            except Exception as e:
                logger.info(f"Time series preprocessing skipped: {e}")

        # --- Step 3: forecasting -------------------------------------------
        if "forecast" in plan["steps"]:
            if ts_df is not None and not ts_df.empty:
                try:
                    # Collapse to a single series: sum portfolio_value per
                    # timestamp. Assumes build_timeseries emits these two
                    # columns — TODO confirm against tools/ts_preprocess.
                    agg = ts_df.groupby("timestamp", as_index=True)["portfolio_value"].sum().sort_index()

                    if len(agg) < 2:
                        errors.append("Insufficient time series data for forecasting (need at least 2 points)")
                    else:
                        # Horizon is capped; with current constants this is
                        # always DEFAULT_FORECAST_HORIZON.
                        horizon = min(DEFAULT_FORECAST_HORIZON, MAX_FORECAST_HORIZON)
                        ts_forecast_df = ts_tool.zeroshot_forecast(agg, horizon=horizon)

                        if isinstance(ts_forecast_df, pd.DataFrame):
                            # The tool signals failure in-band via an
                            # 'error' column rather than raising.
                            if "error" in ts_forecast_df.columns:
                                errors.append(f"Forecast error: {ts_forecast_df['error'].iloc[0]}")
                                ts_forecast_df = None
                            else:
                                artifacts["forecast_horizon"] = len(ts_forecast_df)
                                logger.info(f"Forecast generated for {len(ts_forecast_df)} periods")
                except Exception as e:
                    error_msg = f"Forecasting failed: {str(e)}"
                    logger.error(error_msg)
                    errors.append(error_msg)
            else:
                errors.append("Forecast skipped: no suitable time series data")

        # --- Step 4: explanation (prefers predictions, falls back to raw
        # SQL rows) ---------------------------------------------------------
        if "explain" in plan["steps"]:
            try:
                explain_data = predict_df if predict_df is not None else sql_df
                if explain_data is not None and not explain_data.empty:
                    explain_imgs = explain_tool.run(explain_data)
                    artifacts["explain_charts"] = len(explain_imgs)
                    logger.info(f"Generated {len(explain_imgs)} explanation charts")
                else:
                    errors.append("Explanation skipped: no data available")
            except Exception as e:
                error_msg = f"Explanation failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # --- Step 5: report generation (50-row previews of each artifact;
        # forecast preview substitutes when predictions are absent) ---------
        report_link = None
        if "report" in plan["steps"]:
            try:
                forecast_preview = ts_forecast_df.head(50) if isinstance(ts_forecast_df, pd.DataFrame) else None
                report_link = report_tool.render_and_save(
                    user_query=message,
                    sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
                    predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else forecast_preview,
                    explain_images=explain_imgs,
                    plan=plan,
                )
                logger.info(f"Report generated: {report_link}")
            except Exception as e:
                error_msg = f"Report generation failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # --- Audit trail: record the human review decision alongside the
        # plan, artifact counts, and errors (inputs bounded to 500 chars) ---
        tracer.trace_event("hitl", {
            "message": message[:500],
            "decision": hitl_decision,
            "reviewer_note": reviewer_note[:500] if reviewer_note else "",
            "artifacts": artifacts,
            "plan": plan,
            "errors": errors,
        })

        # --- Assemble the markdown response --------------------------------
        response = f"**Plan:** {', '.join(plan['steps'])}\n\n**Rationale:** {plan['rationale']}\n\n"

        if artifacts:
            response += "**Results:**\n"
            if "sql_rows" in artifacts:
                response += f"- SQL query returned {artifacts['sql_rows']} rows\n"
            if "predict_rows" in artifacts:
                response += f"- Generated predictions for {artifacts['predict_rows']} rows\n"
            if "forecast_horizon" in artifacts:
                response += f"- Forecast generated for {artifacts['forecast_horizon']} periods\n"
            if "explain_charts" in artifacts:
                response += f"- Created {artifacts['explain_charts']} explanation charts\n"
            response += "\n"

        if report_link:
            response += f"π **Report:** {report_link}\n\n"

        if tracer.trace_url:
            response += f"π **Trace:** {tracer.trace_url}\n\n"

        if errors:
            response += "**β οΈ Warnings/Errors:**\n"
            for err in errors:
                response += f"- {err}\n"

        # --- Pick the preview dataframe: most downstream artifact wins -----
        if isinstance(ts_forecast_df, pd.DataFrame) and not ts_forecast_df.empty:
            preview_df = ts_forecast_df.head(100)
        elif isinstance(predict_df, pd.DataFrame) and not predict_df.empty:
            preview_df = predict_df.head(100)
        elif isinstance(sql_df, pd.DataFrame) and not sql_df.empty:
            preview_df = sql_df.head(100)
        else:
            preview_df = pd.DataFrame({"message": ["No data to display"]})

        return response, preview_df

    except Exception as e:
        # Last-resort boundary: log with traceback, trace, and return a
        # user-facing error instead of crashing the Gradio handler.
        error_msg = f"Unexpected error in agent execution: {str(e)}"
        logger.exception(error_msg)
        tracer.trace_event("error", {"message": error_msg})
        return f"β **Critical Error:** {error_msg}", pd.DataFrame()
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI: page layout and event wiring ---------------------------------
with gr.Blocks(title="Tabular Agentic XAI") as demo:
    # Header and capability overview.
    gr.Markdown("""
    # π€ Tabular Agentic XAI (Enterprise Edition)

    An intelligent assistant for analyzing tabular data with ML predictions, explanations, and time-series forecasting.

    **Capabilities:**
    - π SQL queries and data retrieval
    - π― ML predictions with confidence scores
    - π SHAP-based model explanations
    - π Time-series forecasting with Granite TTM
    - π Automated report generation
    """)

    # Free-text query input.
    with gr.Row():
        msg = gr.Textbox(
            label="Ask your question",
            placeholder="e.g., Show me the top 10 customers by revenue, predict churn risk, forecast next quarter...",
            lines=3
        )

    # Human-in-the-loop controls; values are recorded in the trace by
    # run_agent (they do not currently gate execution).
    with gr.Row():
        hitl = gr.Radio(
            ["Approve", "Needs Changes"],
            value="Approve",
            label="Human Review",
            info="Review the planned actions before execution"
        )
        note = gr.Textbox(
            label="Reviewer note (optional)",
            placeholder="Add any review comments...",
            lines=2
        )

    # Outputs: markdown summary plus a read-only data preview.
    out_md = gr.Markdown(label="Response")
    out_df = gr.Dataframe(
        interactive=False,
        label="Data Preview (max 100 rows)",
        wrap=True
    )

    with gr.Row():
        ask = gr.Button("π Run Analysis", variant="primary")
        clear = gr.Button("π Clear")

    # Run the agent with the query and review inputs.
    ask.click(
        run_agent,
        inputs=[msg, hitl, note],
        outputs=[out_md, out_df]
    )

    # Reset all five components to their defaults (radio back to "Approve").
    clear.click(
        lambda: ("", "Approve", "", "", pd.DataFrame()),
        outputs=[msg, hitl, note, out_md, out_df]
    )

    # Footer with usage tips.
    gr.Markdown("""
    ---
    **Tips:**
    - Be specific in your queries for better results
    - Use natural language - the system will interpret your intent
    - Review the execution plan before approving
    - Check the trace link for detailed execution logs
    """)
|
|
|
|
|
|
|
|
# Script entry point: serve the UI on all interfaces at port 7860.
if __name__ == "__main__":
    logger.info("Starting Gradio application...")
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all interfaces (container-friendly)
        "server_port": 7860,
        "show_error": True,        # surface handler exceptions in the UI
    }
    demo.launch(**launch_options)