# Gradio app: agentic tabular-XAI assistant (SQL -> Predict -> Explain -> Report).
import os
import json
import gradio as gr
import pandas as pd
from typing import Dict, Any
from tools.sql_tool import SQLTool
from tools.predict_tool import PredictTool
from tools.explain_tool import ExplainTool
from tools.report_tool import ReportTool
from utils.tracing import Tracer
from utils.config import AppConfig
# Optional: tiny orchestration LLM (keep it simple on CPU)
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
    # Model id is configurable; default is a 0.5B instruct model small enough for CPU.
    LLM_ID = os.getenv("ORCHESTRATOR_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
    _tok = AutoTokenizer.from_pretrained(LLM_ID)
    _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
    # max_new_tokens bounds generation cost for the short planning prompt.
    llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)
except Exception:
    # Broad on purpose: missing transformers, no weights, or offline all mean
    # "run without an LLM" — plan_actions() then uses rule-based routing.
    llm = None  # Fallback: deterministic tool routing without LLM
# App-wide configuration and tracing, both sourced from environment variables.
cfg = AppConfig.from_env()
tracer = Tracer.from_env()
# One shared instance of each tool; all receive the same config and tracer.
sql_tool = SQLTool(cfg, tracer)
predict_tool = PredictTool(cfg, tracer)
explain_tool = ExplainTool(cfg, tracer)
report_tool = ReportTool(cfg, tracer)
# Instructions handed to the optional orchestration LLM in plan_actions().
SYSTEM_PROMPT = (
    "You are an analytical assistant for tabular data. "
    "When the user asks a question, decide which tools to call in order: "
    "1) SQL (if data retrieval is needed) 2) Predict (if scoring is requested) "
    "3) Explain (if attributions or why-questions) 4) Report (if a document is requested). "
    "Always disclose the steps taken and include links to traces if available."
)
# Tools a plan may reference, in canonical execution order.
_KNOWN_STEPS = ("sql", "predict", "explain", "report")

# Keyword triggers for the rule-based fallback planner, keyed by tool.
_STEP_KEYWORDS = {
    "sql": ("show", "average", "count", "trend", "top", "sql", "query", "kpi"),
    "predict": ("predict", "score", "risk", "propensity", "probability"),
    "explain": ("why", "explain", "shap", "feature", "attribution"),
    "report": ("report", "download", "pdf", "summary"),
}


def plan_actions(message: str) -> Dict[str, Any]:
    """Very lightweight planner. Uses LLM if available, else rule-based heuristics.

    Args:
        message: Raw user question.

    Returns:
        Dict with ``steps`` (ordered subset of ['sql','predict','explain','report'])
        and ``rationale`` (one sentence). Both keys are always present.
    """
    if llm is not None:
        prompt = (
            f"{SYSTEM_PROMPT}\nUser: {message}\n"
            "Return JSON with fields: steps (array, subset of ['sql','predict','explain','report']), "
            "and rationale (one sentence)."
        )
        # The pipeline echoes the prompt; assume the JSON plan sits on the last line.
        out = llm(prompt)[0]["generated_text"].split("\n")[-1]
        try:
            plan = json.loads(out)
        except (json.JSONDecodeError, TypeError):
            pass  # Malformed output: fall through to the heuristic planner.
        else:
            # Bug fix: previously any parsed JSON was returned verbatim, so a plan
            # missing 'steps' (or listing unknown tools) crashed run_agent later.
            # Validate shape and keep only known tool names, preserving order.
            if isinstance(plan, dict) and isinstance(plan.get("steps"), list):
                steps = [s for s in plan["steps"] if s in _KNOWN_STEPS]
                if steps:
                    return {
                        "steps": steps,
                        "rationale": str(plan.get("rationale", "LLM plan.")),
                    }
    # Heuristic fallback: keyword match per tool, in canonical order.
    lowered = message.lower()
    steps = [
        step
        for step in _KNOWN_STEPS
        if any(keyword in lowered for keyword in _STEP_KEYWORDS[step])
    ]
    if not steps:
        steps = ["sql"]  # Default to data retrieval when nothing matches.
    return {"steps": steps, "rationale": "Rule-based plan."}
def run_agent(message: str, hitl_decision: str = "Approve", reviewer_note: str = ""):
    """Execute the planned tool chain for one user message.

    Args:
        message: The user's question.
        hitl_decision: Human-in-the-loop verdict ("Approve" / "Needs Changes").
        reviewer_note: Optional free-text note from the human reviewer.

    Returns:
        Tuple of (markdown response string, preview DataFrame or None).
    """
    tracer.trace_event("user_message", {"message": message})
    plan = plan_actions(message)
    tracer.trace_event("plan", plan)
    sql_df = None
    predict_df = None
    explain_plots = {}
    artifacts = {}
    if "sql" in plan["steps"]:
        sql_df = sql_tool.run(message)
        artifacts["sql_rows"] = len(sql_df) if isinstance(sql_df, pd.DataFrame) else 0
    if "predict" in plan["steps"]:
        predict_df = predict_tool.run(sql_df)
    # Bug fix: the original used `predict_df or sql_df`, which invokes
    # DataFrame.__bool__ and always raises ValueError ("truth value of a
    # DataFrame is ambiguous") whenever predict ran. Select on None-ness instead.
    primary_df = predict_df if predict_df is not None else sql_df
    if "explain" in plan["steps"]:
        explain_plots = explain_tool.run(primary_df)
    report_link = None
    if "report" in plan["steps"]:
        report_link = report_tool.render_and_save(
            user_query=message,
            sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
            predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else None,
            explain_images=explain_plots,
            plan=plan,
        )
    # HITL log (append-only). In production, push to a private HF dataset via API.
    hitl_record = {
        "message": message,
        "decision": hitl_decision,
        "reviewer_note": reviewer_note,
        # Timezone-aware UTC; Timestamp.utcnow() is deprecated in pandas 2.x.
        "timestamp": pd.Timestamp.now(tz="UTC").isoformat(),
        "artifacts": artifacts,
        "plan": plan,
    }
    tracer.trace_event("hitl", hitl_record)
    response = f"**Plan:** {plan['steps']}\n**Rationale:** {plan['rationale']}\n"
    if isinstance(sql_df, pd.DataFrame):
        response += f"\n**SQL rows:** {len(sql_df)}"
    if isinstance(predict_df, pd.DataFrame):
        response += f"\n**Predictions rows:** {len(predict_df)}"
    if report_link:
        response += f"\n**Report:** {report_link}"
    if tracer.trace_url:
        response += f"\n**Trace:** {tracer.trace_url}"
    return response, primary_df
# --- Gradio UI: one text input, HITL controls, markdown + dataframe outputs ---
with gr.Blocks() as demo:
    gr.Markdown("# Tabular Agentic XAI (Free‑Tier)")
    with gr.Row():
        msg = gr.Textbox(label="Ask your question")
    with gr.Row():
        # Human-in-the-loop controls; both values are logged by run_agent.
        hitl = gr.Radio(["Approve", "Needs Changes"], value="Approve", label="Human Review")
        note = gr.Textbox(label="Reviewer note (optional)")
    out_md = gr.Markdown()
    out_df = gr.Dataframe(interactive=False)
    ask = gr.Button("Run")
    # run_agent returns (markdown summary, preview dataframe) matching outputs.
    ask.click(run_agent, inputs=[msg, hitl, note], outputs=[out_md, out_df])

if __name__ == "__main__":
    demo.launch()