File size: 6,342 Bytes
68c51bb
af53f4b
 
 
 
 
 
 
 
 
68c51bb
 
 
af53f4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68c51bb
c02152a
af53f4b
 
 
68c51bb
af53f4b
 
 
 
 
 
 
68c51bb
af53f4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68c51bb
af53f4b
 
 
 
 
 
 
 
 
 
 
 
68c51bb
af53f4b
cd9ff90
 
68c51bb
 
 
 
 
 
 
cd9ff90
 
 
 
af53f4b
68c51bb
 
 
 
 
 
 
 
 
af53f4b
 
 
 
 
 
68c51bb
 
af53f4b
 
 
68c51bb
af53f4b
 
 
 
 
 
 
 
 
 
 
 
68c51bb
af53f4b
 
 
68c51bb
 
af53f4b
 
 
68c51bb
 
 
af53f4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# space/app.py
import json
import logging
import os

import gradio as gr
import pandas as pd

from tools.explain_tool import ExplainTool
from tools.predict_tool import PredictTool
from tools.report_tool import ReportTool
from tools.sql_tool import SQLTool
from tools.ts_forecast_tool import TimeseriesForecastTool
from tools.ts_preprocess import build_timeseries

from utils.config import AppConfig
from utils.tracing import Tracer

# Optional tiny CPU LLM for planning. `llm` stays None when ORCHESTRATOR_MODEL
# is unset/empty or when the model cannot be loaded; plan_actions() then uses
# its keyword heuristic instead.
llm = None
LLM_ID = os.getenv("ORCHESTRATOR_MODEL")
if LLM_ID:
    try:
        # Imported here so transformers is only needed when a model is configured.
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

        _tok = AutoTokenizer.from_pretrained(LLM_ID)
        _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
        llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)
    except Exception:
        # Best-effort: the app must still start without the planner model, but
        # record the reason so a bad model id is not silently ignored.
        logging.getLogger(__name__).warning(
            "Could not load ORCHESTRATOR_MODEL=%r; falling back to rule-based planning.",
            LLM_ID,
            exc_info=True,
        )
        llm = None

# Application config and tracer, each built through its from_env() constructor.
cfg = AppConfig.from_env()
tracer = Tracer.from_env()

# Every tool receives the same (config, tracer) pair.
_tool_args = (cfg, tracer)
sql_tool = SQLTool(*_tool_args)
predict_tool = PredictTool(*_tool_args)
explain_tool = ExplainTool(*_tool_args)
report_tool = ReportTool(*_tool_args)
ts_tool = TimeseriesForecastTool(*_tool_args)  # Granite wrapper

# Instruction preamble placed ahead of the user's message in the planner prompt.
SYSTEM_PROMPT = " ".join(
    (
        "You are an analytical assistant for tabular data.",
        "Decide which tools to call in order:",
        "1) SQL (retrieve) 2) Predict (score) 3) Explain (SHAP) 4) Report (document) 5) Forecast (Granite TTM).",
        "Always disclose the steps taken.",
    )
)

def plan_actions(message: str):
    """Decide which tool steps to run for a user message.

    When the optional planner ``llm`` is loaded, it is asked for a JSON plan
    and the reply is parsed and validated. If no model is loaded, the call
    fails, or no usable plan can be extracted, a keyword heuristic over the
    message is used instead.

    Args:
        message: The user's free-text question.

    Returns:
        A dict with ``steps`` (a non-empty list drawn from 'sql', 'predict',
        'explain', 'report', 'forecast') and ``rationale`` (a string). Any
        extra keys the model returned alongside a valid plan are kept.
    """
    if llm is not None:
        prompt = (
            f"{SYSTEM_PROMPT}\nUser: {message}\n"
            "Return JSON with fields: steps (array subset of ['sql','predict','explain','report','forecast']), rationale."
        )
        try:
            out = llm(prompt)[0]["generated_text"]
        except Exception:
            # Planning is best-effort: note the failure and use the heuristic.
            logging.getLogger(__name__).warning(
                "LLM planning failed; using rule-based plan.", exc_info=True
            )
        else:
            plan = _parse_llm_plan(out, prompt)
            if plan is not None:
                return plan
    # Fallback heuristic:
    return {"steps": _heuristic_steps(message), "rationale": "Rule-based plan."}


# Step name -> substrings of the lower-cased message that trigger it.
# Order here is the order steps appear in a rule-based plan.
_STEP_KEYWORDS = (
    ("sql", ("show", "average", "count", "trend", "top", "sql", "query", "kpi")),
    ("predict", ("predict", "score", "risk", "propensity", "probability")),
    ("explain", ("why", "explain", "shap", "feature", "attribution")),
    ("report", ("report", "download", "pdf", "summary")),
    ("forecast", ("forecast", "next", "horizon", "granite")),
)
_KNOWN_STEPS = tuple(step for step, _ in _STEP_KEYWORDS)


def _heuristic_steps(message):
    """Return the steps whose keywords occur in the message; ['sql'] if none do."""
    lowered = (message or "").lower()
    matched = [
        step
        for step, keywords in _STEP_KEYWORDS
        if any(keyword in lowered for keyword in keywords)
    ]
    return matched or ["sql"]


def _valid_steps(raw):
    """Normalize a model-supplied ``steps`` value to a de-duplicated list of known step names."""
    if isinstance(raw, str):
        # A bare string would otherwise make `"sql" in steps` a substring test.
        raw = [raw]
    if not isinstance(raw, (list, tuple)):
        return []
    cleaned = []
    for item in raw:
        name = item.strip().lower() if isinstance(item, str) else None
        if name in _KNOWN_STEPS and name not in cleaned:
            cleaned.append(name)
    return cleaned


def _parse_llm_plan(text, prompt):
    """Extract a validated plan dict from raw model output, or return None."""
    if not isinstance(text, str):
        return None
    # The generated text may start with an echo of the prompt, which embeds
    # the user's message; only the completion should be parsed as a plan.
    if text.startswith(prompt):
        text = text[len(prompt):]
    decoder = json.JSONDecoder()
    # Try each '{' from the last one backwards. raw_decode tolerates trailing
    # text, and walking backwards reaches the outer object of nested JSON.
    pos = text.rfind("{")
    while pos != -1:
        try:
            obj, _ = decoder.raw_decode(text, pos)
        except ValueError:
            obj = None
        if isinstance(obj, dict) and "steps" in obj:
            steps = _valid_steps(obj["steps"])
            if steps:
                rationale = obj.get("rationale")
                if not isinstance(rationale, str) or not rationale.strip():
                    rationale = "LLM plan (no rationale provided)."
                return {**obj, "steps": steps, "rationale": rationale}
        pos = text.rfind("{", 0, pos)
    return None

def run_agent(message: str, hitl_decision: str = "Approve", reviewer_note: str = ""):
    """Plan and execute the tool pipeline for one user message.

    The plan comes from ``plan_actions``. Steps run in a fixed order
    (SQL, predict, forecast, explain, report), each only if it is listed in
    the plan. The review decision and note are written to the "hitl" trace
    event; they do not gate execution.

    Args:
        message: The user's question; also passed to the SQL tool.
        hitl_decision: Human-review choice recorded in the trace.
        reviewer_note: Optional reviewer comment recorded in the trace.

    Returns:
        A ``(markdown, dataframe)`` tuple: a Markdown summary of what ran, and
        a preview frame (forecast if non-empty, else predictions if non-empty,
        else the SQL result, else an empty DataFrame).
    """
    tracer.trace_event("user_message", {"message": message})
    plan = plan_actions(message)
    tracer.trace_event("plan", plan)

    # Read defensively: a model-produced plan may lack one of the keys.
    steps = plan.get("steps") or []
    rationale = plan.get("rationale", "")

    sql_df = None
    predict_df = None
    explain_imgs = {}
    artifacts = {}
    ts_forecast_df = None

    if "sql" in steps:
        sql_df = sql_tool.run(message)
        artifacts["sql_rows"] = int(len(sql_df)) if isinstance(sql_df, pd.DataFrame) else 0

    if "predict" in steps:
        predict_df = predict_tool.run(sql_df)

    # NOTE(review): a forecast is only attempted when the SQL step produced data;
    # a plan with 'forecast' but without 'sql' yields no forecast.
    if "forecast" in steps and sql_df is not None:
        ts_forecast_df = _forecast_portfolio(sql_df)

    if "explain" in steps:
        # `predict_df or sql_df` raises ValueError for a DataFrame (ambiguous
        # truth value), so pick the input with an explicit None check.
        explain_input = predict_df if predict_df is not None else sql_df
        explain_imgs = explain_tool.run(explain_input)

    report_link = None
    if "report" in steps:
        # Add forecast preview if available
        forecast_preview = ts_forecast_df.head(50) if isinstance(ts_forecast_df, pd.DataFrame) else None
        report_link = report_tool.render_and_save(
            user_query=message,
            sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
            predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else forecast_preview,
            explain_images=explain_imgs,
            plan=plan,
        )

    tracer.trace_event("hitl", {
        "message": message,
        "decision": hitl_decision,
        "reviewer_note": reviewer_note,
        "artifacts": artifacts,
        "plan": plan,
    })

    # Compose response
    response = f"**Plan:** {steps}\n**Rationale:** {rationale}\n"
    if isinstance(sql_df, pd.DataFrame):
        response += f"\n**SQL rows:** {len(sql_df)}"
    if isinstance(predict_df, pd.DataFrame):
        response += f"\n**Predictions rows:** {len(predict_df)}"
    if isinstance(ts_forecast_df, pd.DataFrame) and "forecast" in ts_forecast_df.columns:
        response += f"\n**Forecast horizon:** {len(ts_forecast_df)}"
    if report_link:
        response += f"\n**Report:** {report_link}"
    if tracer.trace_url:
        response += f"\n**Trace:** {tracer.trace_url}"

    # Prefer to show forecast if present, else predictions, else raw query
    if isinstance(ts_forecast_df, pd.DataFrame) and not ts_forecast_df.empty:
        preview_df = ts_forecast_df
    elif isinstance(predict_df, pd.DataFrame) and not predict_df.empty:
        preview_df = predict_df
    else:
        preview_df = sql_df
    return response, (preview_df if isinstance(preview_df, pd.DataFrame) else pd.DataFrame())


def _forecast_portfolio(sql_df, horizon=96):
    """Build a time series from the SQL result and forecast its summed portfolio value.

    Returns whatever the forecast tool returns, ``None`` when no time series
    could be built, or a one-row DataFrame with an ``error`` column when
    aggregation or forecasting fails.
    """
    try:
        ts_df = build_timeseries(sql_df)
    except Exception:
        # Not every SQL result can be turned into a time series; skip the
        # forecast in that case, but leave a record of why.
        logging.getLogger(__name__).warning(
            "Could not build a time series from the SQL result; skipping forecast.",
            exc_info=True,
        )
        return None
    if ts_df is None:
        return None
    try:
        # Expect 'timestamp' and 'portfolio_value' after preprocessing.
        # Use the combined series, i.e. the sum over all rows per timestamp.
        # Kept inside the try so a missing column is reported, not raised.
        agg = ts_df.groupby("timestamp", as_index=True)["portfolio_value"].sum().sort_index()
        return ts_tool.zeroshot_forecast(agg, horizon=horizon)
    except Exception as e:
        # Surface a readable error in response
        return pd.DataFrame({"error": [str(e)]})

# Gradio UI: a question box, human-review controls, Markdown and Dataframe
# outputs, and a button that runs the agent.
with gr.Blocks() as demo:
    gr.Markdown("# Tabular Agentic XAI (Free-Tier)")
    with gr.Row():
        msg = gr.Textbox(label="Ask your question")
    with gr.Row():
        # Review decision and note are passed to run_agent, which records
        # them in its "hitl" trace event.
        hitl = gr.Radio(["Approve", "Needs Changes"], value="Approve", label="Human Review")
        note = gr.Textbox(label="Reviewer note (optional)")
    out_md = gr.Markdown()
    out_df = gr.Dataframe(interactive=False)
    ask = gr.Button("Run")
    # The click handler calls run_agent with (msg, hitl, note); its two return
    # values are routed to out_md and out_df respectively.
    ask.click(run_agent, inputs=[msg, hitl, note], outputs=[out_md, out_df])

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()