import os
import json
import gradio as gr
import pandas as pd
from tools.sql_tool import SQLTool
from tools.predict_tool import PredictTool
from tools.explain_tool import ExplainTool
from tools.report_tool import ReportTool
from utils.tracing import Tracer
from utils.config import AppConfig
from tools.ts_forecast_tool import TimeseriesForecastTool
# Optional tiny CPU LLM for planning (can be disabled by not setting ORCHESTRATOR_MODEL)
llm = None
LLM_ID = os.getenv("ORCHESTRATOR_MODEL")
if LLM_ID:
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
        _tok = AutoTokenizer.from_pretrained(LLM_ID)
        _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
        llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)
    except Exception:
        llm = None
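# Example planner configuration (illustrative; the model id is an assumption, not
# part of this repo): any small CPU-friendly instruct model can be used, e.g.
#   export ORCHESTRATOR_MODEL="Qwen/Qwen2.5-0.5B-Instruct"
# With the variable unset, plan_actions() below uses only the rule-based fallback.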
cfg = AppConfig.from_env()
tracer = Tracer.from_env()
sql_tool = SQLTool(cfg, tracer)
predict_tool = PredictTool(cfg, tracer)
explain_tool = ExplainTool(cfg, tracer)
report_tool = ReportTool(cfg, tracer)
ts_tool = TimeseriesForecastTool(
    cfg, tracer,
    context_length=512, forecast_length=96,
    target_cols=["portfolio_value"],              # set after preprocessing
    control_cols=["rate_deposit", "rate_asset"],  # optional exogenous
)
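# Note: ts_tool is instantiated here but not yet referenced by plan_actions() or
# run_agent() below; wiring it in as an extra step (e.g. "forecast") is left open.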
SYSTEM_PROMPT = (
"You are an analytical assistant for tabular data. "
"Decide which tools to call in order: "
"1) SQL (retrieve) 2) Predict (score) 3) Explain (SHAP) 4) Report (document). "
"Always disclose the steps taken."
)
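# Expected planner output (illustrative example only, not emitted by this file):
#   {"steps": ["sql", "predict", "explain", "report"], "rationale": "..."}
# plan_actions() tries to parse this JSON from the LLM completion and otherwise
# falls back to keyword matching.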
def plan_actions(message: str):
    if llm is not None:
        prompt = (
            f"{SYSTEM_PROMPT}\nUser: {message}\n"
            "Return JSON with fields: steps (array subset of ['sql','predict','explain','report']), rationale."
        )
        try:
            out = llm(prompt)[0]["generated_text"]
            last = out.split("\n")[-1].strip()
            obj = json.loads(last) if last.startswith("{") else json.loads(out[out.rfind("{"):])
            if isinstance(obj, dict) and "steps" in obj:
                return obj
        except Exception:
            pass
    # Fallback heuristic:
    m = message.lower()
    steps = []
    if any(k in m for k in ["show", "average", "count", "trend", "top", "sql", "query", "kpi"]):
        steps.append("sql")
    if any(k in m for k in ["predict", "score", "risk", "propensity", "probability"]):
        steps.append("predict")
    if any(k in m for k in ["why", "explain", "shap", "feature", "attribution"]):
        steps.append("explain")
    if any(k in m for k in ["report", "download", "pdf", "summary"]):
        steps.append("report")
    if not steps:
        steps = ["sql"]
    return {"steps": steps, "rationale": "Rule-based plan."}
def run_agent(message: str, hitl_decision: str = "Approve", reviewer_note: str = ""):
tracer.trace_event("user_message", {"message": message})
plan = plan_actions(message)
tracer.trace_event("plan", plan)
sql_df = None
predict_df = None
explain_imgs = {}
artifacts = {}
if "sql" in plan["steps"]:
sql_df = sql_tool.run(message)
artifacts["sql_rows"] = int(len(sql_df)) if isinstance(sql_df, pd.DataFrame) else 0
if "predict" in plan["steps"]:
predict_df = predict_tool.run(sql_df)
if "explain" in plan["steps"]:
        # `predict_df or sql_df` would raise on DataFrames (ambiguous truth value),
        # so choose explicitly: explain predictions if available, else the SQL result.
        explain_imgs = explain_tool.run(predict_df if predict_df is not None else sql_df)
    report_link = None
    if "report" in plan["steps"]:
        report_link = report_tool.render_and_save(
            user_query=message,
            sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
            predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else None,
            explain_images=explain_imgs,
            plan=plan,
        )
    tracer.trace_event("hitl", {
        "message": message,
        "decision": hitl_decision,
        "reviewer_note": reviewer_note,
        "artifacts": artifacts,
        "plan": plan,
    })
    response = f"**Plan:** {plan['steps']}\n**Rationale:** {plan['rationale']}\n"
    if isinstance(sql_df, pd.DataFrame):
        response += f"\n**SQL rows:** {len(sql_df)}"
    if isinstance(predict_df, pd.DataFrame):
        response += f"\n**Predictions rows:** {len(predict_df)}"
    if report_link:
        response += f"\n**Report:** {report_link}"
    if tracer.trace_url:
        response += f"\n**Trace:** {tracer.trace_url}"
    preview_df = predict_df if isinstance(predict_df, pd.DataFrame) and len(predict_df) else sql_df
    return response, (preview_df if isinstance(preview_df, pd.DataFrame) else pd.DataFrame())
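# Illustrative call outside the Gradio UI (assumes cfg/tracer and all tools are
# configured; the question text is hypothetical):
#   text, table = run_agent("Show average balance by segment", hitl_decision="Approve")
#   `text` is a Markdown summary of the plan and artifacts, `table` a DataFrame preview.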
with gr.Blocks() as demo:
gr.Markdown("# Tabular Agentic XAI (Free-Tier)")
with gr.Row():
msg = gr.Textbox(label="Ask your question")
with gr.Row():
hitl = gr.Radio(["Approve", "Needs Changes"], value="Approve", label="Human Review")
note = gr.Textbox(label="Reviewer note (optional)")
out_md = gr.Markdown()
out_df = gr.Dataframe(interactive=False)
ask = gr.Button("Run")
ask.click(run_agent, inputs=[msg, hitl, note], outputs=[out_md, out_df])
if __name__ == "__main__":
    demo.launch()
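# Optional tweak (an assumption, not part of the original file): queue requests so
# long-running SQL/predict/explain calls don't time out the HTTP connection, e.g.
#   demo.queue().launch()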