File size: 5,080 Bytes
af53f4b
 
 
 
 
 
 
 
 
 
 
c02152a
af53f4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c02152a
 
 
 
 
 
af53f4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import json
import gradio as gr
import pandas as pd

from tools.sql_tool import SQLTool
from tools.predict_tool import PredictTool
from tools.explain_tool import ExplainTool
from tools.report_tool import ReportTool
from utils.tracing import Tracer
from utils.config import AppConfig
from tools.ts_forecast_tool import TimeseriesForecastTool

# Optional tiny CPU LLM used only for planning. It is enabled by setting the
# ORCHESTRATOR_MODEL environment variable; when the variable is unset, or the
# model cannot be loaded, `llm` stays None.
llm = None
LLM_ID = os.getenv("ORCHESTRATOR_MODEL")
if LLM_ID:
    try:
        # Imported lazily so transformers is only needed when a model is configured.
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
        _tok = AutoTokenizer.from_pretrained(LLM_ID)
        _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
        llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)
    except Exception:
        # Deliberate best-effort: a broken model setup must not stop the app
        # from starting, but the failure is logged instead of being hidden.
        import logging

        logging.getLogger(__name__).warning(
            "Could not load ORCHESTRATOR_MODEL %r; continuing without a planner LLM.",
            LLM_ID,
            exc_info=True,
        )
        llm = None

# Shared configuration and tracer, each obtained from its from_env() constructor.
cfg = AppConfig.from_env()
tracer = Tracer.from_env()

# One tool instance per pipeline stage; all receive the same cfg and tracer.
sql_tool = SQLTool(cfg, tracer)
predict_tool = PredictTool(cfg, tracer)
explain_tool = ExplainTool(cfg, tracer)
report_tool = ReportTool(cfg, tracer)

# Time-series forecasting tool.
# NOTE(review): the column names below presumably refer to the preprocessed
# dataset -- confirm against the data pipeline.
ts_tool = TimeseriesForecastTool(
    cfg,
    tracer,
    context_length=512,
    forecast_length=96,
    target_cols=["portfolio_value"],  # set after preprocessing
    control_cols=["rate_deposit", "rate_asset"],  # optional exogenous
)

# Instruction text placed at the start of the planner LLM prompt.
# The sentences are joined with single spaces into one line.
SYSTEM_PROMPT = " ".join(
    [
        "You are an analytical assistant for tabular data.",
        "Decide which tools to call in order:",
        "1) SQL (retrieve) 2) Predict (score) 3) Explain (SHAP) 4) Report (document).",
        "Always disclose the steps taken.",
    ]
)

def _first_plan_object(text):
    """Return the first JSON object in *text* that has a "steps" key, else None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode tolerates trailing text and handles nested objects.
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict) and "steps" in candidate:
            return candidate
        start = text.find("{", start + 1)
    return None


def _normalize_steps(raw, allowed):
    """Return the known step names found in *raw*, de-duplicated, in order."""
    if isinstance(raw, str):
        raw = [raw]
    if not isinstance(raw, (list, tuple)):
        return []
    steps = []
    for item in raw:
        if not isinstance(item, str):
            continue
        name = item.strip().lower()
        if name in allowed and name not in steps:
            steps.append(name)
    return steps


def plan_actions(message: str):
    """Decide which tool steps to run for a user message.

    When the module-level ``llm`` is available it is asked for a JSON plan.
    The reply is accepted only if it contains at least one known step name;
    a missing rationale is replaced by a default string. On any failure, or
    when no LLM is configured, a keyword heuristic over the lower-cased
    message is used, defaulting to ``["sql"]`` when no keyword matches.

    Args:
        message: The user's question.

    Returns:
        A dict with at least ``"steps"`` (list of step names drawn from
        'sql', 'predict', 'explain', 'report') and ``"rationale"``.
    """
    known_steps = ("sql", "predict", "explain", "report")

    if llm is not None:
        prompt = (
            f"{SYSTEM_PROMPT}\nUser: {message}\n"
            "Return JSON with fields: steps (array subset of ['sql','predict','explain','report']), rationale."
        )
        try:
            out = llm(prompt)[0]["generated_text"]
            # The generated text may start with the prompt itself; drop it so
            # braces typed by the user are never mistaken for the plan.
            if out.startswith(prompt):
                out = out[len(prompt):]
            obj = _first_plan_object(out)
            if obj is not None:
                steps = _normalize_steps(obj["steps"], known_steps)
                if steps:
                    rationale = obj.get("rationale") or "LLM-generated plan."
                    return {**obj, "steps": steps, "rationale": rationale}
        except Exception:
            # Best-effort: any LLM failure falls through to the heuristic,
            # but is logged so the degradation is visible.
            import logging

            logging.getLogger(__name__).warning(
                "LLM planning failed; using rule-based plan.", exc_info=True
            )

    # Fallback heuristic: substring keyword match per step, in pipeline order.
    keyword_table = (
        ("sql", ("show", "average", "count", "trend", "top", "sql", "query", "kpi")),
        ("predict", ("predict", "score", "risk", "propensity", "probability")),
        ("explain", ("why", "explain", "shap", "feature", "attribution")),
        ("report", ("report", "download", "pdf", "summary")),
    )
    lowered = message.lower()
    steps = [
        step
        for step, keywords in keyword_table
        if any(keyword in lowered for keyword in keywords)
    ]
    if not steps:
        steps = ["sql"]
    return {"steps": steps, "rationale": "Rule-based plan."}

def run_agent(message: str, hitl_decision: str = "Approve", reviewer_note: str = ""):
    """Plan and run the tool pipeline for one user message.

    The plan from ``plan_actions`` selects which of the SQL, predict, explain
    and report tools run, in that order. The user message, the plan and the
    human-review decision are sent to the tracer.

    Args:
        message: The user's question; also passed to the SQL tool.
        hitl_decision: Human-review decision, recorded in the trace only.
        reviewer_note: Optional reviewer note, recorded in the trace only.

    Returns:
        A ``(markdown_text, dataframe)`` tuple. The DataFrame is the
        non-empty prediction result if there is one, otherwise the SQL
        result, otherwise an empty DataFrame.
    """
    tracer.trace_event("user_message", {"message": message})
    plan = plan_actions(message)
    tracer.trace_event("plan", plan)
    steps = plan["steps"]

    sql_df = None
    predict_df = None
    explain_imgs = {}
    artifacts = {}

    if "sql" in steps:
        sql_df = sql_tool.run(message)
        artifacts["sql_rows"] = int(len(sql_df)) if isinstance(sql_df, pd.DataFrame) else 0

    if "predict" in steps:
        predict_df = predict_tool.run(sql_df)

    if "explain" in steps:
        # `predict_df or sql_df` would evaluate bool(DataFrame), which raises
        # ValueError; choose the input with an explicit None check instead.
        explain_input = predict_df if predict_df is not None else sql_df
        explain_imgs = explain_tool.run(explain_input)

    report_link = None
    if "report" in steps:
        report_link = report_tool.render_and_save(
            user_query=message,
            sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
            predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else None,
            explain_images=explain_imgs,
            plan=plan,
        )

    tracer.trace_event("hitl", {
        "message": message,
        "decision": hitl_decision,
        "reviewer_note": reviewer_note,
        "artifacts": artifacts,
        "plan": plan,
    })

    # An LLM-produced plan may lack a rationale; do not fail the whole run on it.
    rationale = plan.get("rationale", "")
    response = f"**Plan:** {steps}\n**Rationale:** {rationale}\n"
    if isinstance(sql_df, pd.DataFrame):
        response += f"\n**SQL rows:** {len(sql_df)}"
    if isinstance(predict_df, pd.DataFrame):
        response += f"\n**Predictions rows:** {len(predict_df)}"
    if report_link:
        response += f"\n**Report:** {report_link}"
    if tracer.trace_url:
        response += f"\n**Trace:** {tracer.trace_url}"

    # Prefer non-empty predictions for the preview table, else the SQL result.
    if isinstance(predict_df, pd.DataFrame) and len(predict_df):
        preview_df = predict_df
    else:
        preview_df = sql_df
    if not isinstance(preview_df, pd.DataFrame):
        preview_df = pd.DataFrame()
    return response, preview_df

# Gradio UI: a question box, human-review controls, and two output areas
# (Markdown text and a read-only table). Components are created in display
# order inside the Blocks context, so statement order here is significant.
with gr.Blocks() as demo:
    gr.Markdown("# Tabular Agentic XAI (Free-Tier)")
    with gr.Row():
        msg = gr.Textbox(label="Ask your question")
    with gr.Row():
        # Human-in-the-loop inputs; run_agent records them in the trace.
        hitl = gr.Radio(["Approve", "Needs Changes"], value="Approve", label="Human Review")
        note = gr.Textbox(label="Reviewer note (optional)")
    out_md = gr.Markdown()
    out_df = gr.Dataframe(interactive=False)
    ask = gr.Button("Run")
    # The three inputs map to run_agent's (message, hitl_decision, reviewer_note);
    # its (response, dataframe) return value fills the two outputs.
    ask.click(run_agent, inputs=[msg, hitl, note], outputs=[out_md, out_df])

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()