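"""Gradio entry point for the tabular agentic-XAI demo.

Plans which tools to call (SQL, predict, explain, report), runs them in order,
and records human-in-the-loop review decisions alongside trace events.
"""
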
import os
import json
import gradio as gr
import pandas as pd
from typing import Dict, Any

from tools.sql_tool import SQLTool
from tools.predict_tool import PredictTool
from tools.explain_tool import ExplainTool
from tools.report_tool import ReportTool
from utils.tracing import Tracer
from utils.config import AppConfig

# Optional: tiny orchestration LLM (keep it simple on CPU)
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
    LLM_ID = os.getenv("ORCHESTRATOR_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
    _tok = AutoTokenizer.from_pretrained(LLM_ID)
    _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
    llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)
except Exception:
    llm = None  # Fallback: deterministic tool routing without LLM

cfg = AppConfig.from_env()
tracer = Tracer.from_env()

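# Tools available to the planner; each receives the shared config and tracer.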
sql_tool = SQLTool(cfg, tracer)
predict_tool = PredictTool(cfg, tracer)
explain_tool = ExplainTool(cfg, tracer)
report_tool = ReportTool(cfg, tracer)

SYSTEM_PROMPT = (
    "You are an analytical assistant for tabular data. "
    "When the user asks a question, decide which tools to call in order: "
    "1) SQL (if data retrieval is needed) 2) Predict (if scoring is requested) "
    "3) Explain (if attributions or why-questions) 4) Report (if a document is requested). "
    "Always disclose the steps taken and include links to traces if available."
)


def plan_actions(message: str) -> Dict[str, Any]:
    """Very lightweight planner. Uses LLM if available, else rule-based heuristics."""
    if llm is not None:
        prompt = (
            f"{SYSTEM_PROMPT}\nUser: {message}\n"
            "Return JSON with fields: steps (array, subset of ['sql','predict','explain','report']), "
            "and rationale (one sentence)."
        )
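        # The text-generation pipeline returns the prompt plus the completion by default,
        # so take the last line of the output and try to parse it as the JSON plan;
        # on any failure, fall through to the rule-based heuristics below.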
        out = llm(prompt)[0]["generated_text"].split("\n")[-1]
        try:
            plan = json.loads(out)
            return plan
        except Exception:
            pass
    # Heuristic fallback
    steps = []
    m = message.lower()
    if any(k in m for k in ["show", "average", "count", "trend", "top", "sql", "query", "kpi"]):
        steps.append("sql")
    if any(k in m for k in ["predict", "score", "risk", "propensity", "probability"]):
        steps.append("predict")
    if any(k in m for k in ["why", "explain", "shap", "feature", "attribution"]):
        steps.append("explain")
    if any(k in m for k in ["report", "download", "pdf", "summary"]):
        steps.append("report")
    if not steps:
        steps = ["sql"]
    return {"steps": steps, "rationale": "Rule-based plan."}


def run_agent(message: str, hitl_decision: str = "Approve", reviewer_note: str = ""):
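    """Run the planned tools in order, log the HITL decision, and return a markdown summary plus a preview DataFrame."""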
    tracer.trace_event("user_message", {"message": message})
    plan = plan_actions(message)
    tracer.trace_event("plan", plan)

    sql_df = None
    predict_df = None
    explain_plots = {}
    artifacts = {}

    if "sql" in plan["steps"]:
        sql_df = sql_tool.run(message)
        artifacts["sql_rows"] = len(sql_df) if isinstance(sql_df, pd.DataFrame) else 0

    if "predict" in plan["steps"]:
        predict_df = predict_tool.run(sql_df)

    if "explain" in plan["steps"]:
        # `or` on a DataFrame raises "truth value is ambiguous", so choose explicitly
        explain_plots = explain_tool.run(predict_df if predict_df is not None else sql_df)

    report_link = None
    if "report" in plan["steps"]:
        report_link = report_tool.render_and_save(
            user_query=message,
            sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
            predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else None,
            explain_images=explain_plots,
            plan=plan,
        )

    # HITL log (append-only). In production, push to a private HF dataset via API.
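    # (One option, not wired up here: append each record to a local JSONL file and
    #  periodically upload it with huggingface_hub.HfApi().upload_file(...).)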
    hitl_record = {
        "message": message,
        "decision": hitl_decision,
        "reviewer_note": reviewer_note,
        "timestamp": pd.Timestamp.utcnow().isoformat(),
        "artifacts": artifacts,
        "plan": plan,
    }
    tracer.trace_event("hitl", hitl_record)

    response = f"**Plan:** {plan['steps']}\n**Rationale:** {plan['rationale']}\n"
    if isinstance(sql_df, pd.DataFrame):
        response += f"\n**SQL rows:** {len(sql_df)}"
    if isinstance(predict_df, pd.DataFrame):
        response += f"\n**Predictions rows:** {len(predict_df)}"
    if report_link:
        response += f"\n**Report:** {report_link}"
    if tracer.trace_url:
        response += f"\n**Trace:** {tracer.trace_url}"

    preview_df = predict_df if predict_df is not None else sql_df
    return response, preview_df

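# Gradio UI: a question box, human-review controls, and two outputs (markdown summary and data preview).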
with gr.Blocks() as demo:
    gr.Markdown("# Tabular Agentic XAI (Free‑Tier)")
    with gr.Row():
        msg = gr.Textbox(label="Ask your question")
    with gr.Row():
        hitl = gr.Radio(["Approve", "Needs Changes"], value="Approve", label="Human Review")
        note = gr.Textbox(label="Reviewer note (optional)")
    out_md = gr.Markdown()
    out_df = gr.Dataframe(interactive=False)
    ask = gr.Button("Run")
    ask.click(run_agent, inputs=[msg, hitl, note], outputs=[out_md, out_df])

if __name__ == "__main__":
    demo.launch()