# space/app.py
import os
import json
import gradio as gr
import pandas as pd
from tools.sql_tool import SQLTool
from tools.predict_tool import PredictTool
from tools.explain_tool import ExplainTool
from tools.report_tool import ReportTool
from tools.ts_preprocess import build_timeseries
from tools.ts_forecast_tool import TimeseriesForecastTool
from utils.tracing import Tracer
from utils.config import AppConfig
# Optional tiny CPU LLM for planning (can be disabled by not setting ORCHESTRATOR_MODEL)
# `llm` stays None when no model id is configured or loading fails; planning
# then falls back to the rule-based heuristic in plan_actions().
llm = None
LLM_ID = os.getenv("ORCHESTRATOR_MODEL")
if LLM_ID:
    try:
        # Imported lazily so the app works without transformers installed.
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
        _tok = AutoTokenizer.from_pretrained(LLM_ID)
        _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
        # max_new_tokens bounds generation cost on free-tier CPU hardware.
        llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)
    except Exception:
        # Best-effort: a missing dependency or unreachable model simply
        # disables LLM planning instead of crashing the app at import time.
        llm = None
# Application-wide singletons, built once at import time from environment config.
cfg = AppConfig.from_env()
tracer = Tracer.from_env()

# Tool instances shared by every request; each wraps one agent capability.
sql_tool = SQLTool(cfg, tracer)
predict_tool = PredictTool(cfg, tracer)
explain_tool = ExplainTool(cfg, tracer)
report_tool = ReportTool(cfg, tracer)
ts_tool = TimeseriesForecastTool(cfg, tracer)  # Granite wrapper

# Planning instructions prepended to the user message when the optional LLM
# planner is active (see plan_actions); also documents the tool ordering.
SYSTEM_PROMPT = (
    "You are an analytical assistant for tabular data. "
    "Decide which tools to call in order: "
    "1) SQL (retrieve) 2) Predict (score) 3) Explain (SHAP) 4) Report (document) 5) Forecast (Granite TTM). "
    "Always disclose the steps taken."
)
def plan_actions(message: str):
    """Produce an ordered tool-execution plan for *message*.

    Tries the optional planner LLM first; on any failure (no model, bad
    generation, malformed JSON, no recognized steps) falls back to a
    keyword heuristic.

    Args:
        message: Raw natural-language user request.

    Returns:
        dict with keys:
            steps: ordered list drawn from
                   ('sql', 'predict', 'explain', 'report', 'forecast')
            rationale: short human-readable justification.
    """
    valid_steps = ("sql", "predict", "explain", "report", "forecast")
    if llm is not None:
        prompt = (
            f"{SYSTEM_PROMPT}\nUser: {message}\n"
            "Return JSON with fields: steps (array subset of ['sql','predict','explain','report','forecast']), rationale."
        )
        try:
            out = llm(prompt)[0]["generated_text"]
            # The JSON object is expected at the end of the generation; take
            # the last brace-delimited span (the plan object is flat, so
            # rfind('{') lands on its opening brace).
            start = out.rfind("{")
            if start != -1:
                obj = json.loads(out[start:])
                if isinstance(obj, dict) and isinstance(obj.get("steps"), list):
                    # Keep only recognized tool names, preserving order —
                    # never forward hallucinated steps to run_agent().
                    steps = [s for s in obj["steps"] if s in valid_steps]
                    if steps:
                        return {"steps": steps,
                                "rationale": obj.get("rationale", "LLM plan.")}
        except Exception:
            # Malformed generation — fall through to the rule-based planner.
            pass
    # Fallback heuristic: keyword-triggered steps, in pipeline order.
    m = message.lower()
    steps = []
    if any(k in m for k in ("show", "average", "count", "trend", "top", "sql", "query", "kpi")):
        steps.append("sql")
    if any(k in m for k in ("predict", "score", "risk", "propensity", "probability")):
        steps.append("predict")
    if any(k in m for k in ("why", "explain", "shap", "feature", "attribution")):
        steps.append("explain")
    if any(k in m for k in ("report", "download", "pdf", "summary")):
        steps.append("report")
    if any(k in m for k in ("forecast", "next", "horizon", "granite")):
        steps.append("forecast")
    # Default to a plain SQL retrieval when nothing matched.
    return {"steps": steps or ["sql"], "rationale": "Rule-based plan."}
def run_agent(message: str, hitl_decision: str = "Approve", reviewer_note: str = ""):
    """Execute the planned tool pipeline for one user message.

    Args:
        message: Natural-language request from the UI.
        hitl_decision: Human-in-the-loop verdict ("Approve" / "Needs Changes").
        reviewer_note: Optional free-text reviewer comment.

    Returns:
        Tuple of (markdown summary, preview DataFrame) matching the two
        Gradio outputs; the DataFrame is empty when no tool produced one.
    """
    tracer.trace_event("user_message", {"message": message})
    plan = plan_actions(message)
    tracer.trace_event("plan", plan)

    sql_df = None
    predict_df = None
    explain_imgs = {}
    artifacts = {}
    ts_forecast_df = None

    if "sql" in plan["steps"]:
        sql_df = sql_tool.run(message)
        artifacts["sql_rows"] = int(len(sql_df)) if isinstance(sql_df, pd.DataFrame) else 0
    if "predict" in plan["steps"]:
        predict_df = predict_tool.run(sql_df)

    # Best-effort long-format time series from the SQL result (for forecasting).
    ts_df = None
    if sql_df is not None:
        try:
            ts_df = build_timeseries(sql_df)
        except Exception:
            ts_df = None

    if "forecast" in plan["steps"] and ts_df is not None:
        # Expect 'portfolio_value' after preprocessing
        # Use the combined series — e.g., sum over instruments by timestamp
        agg = ts_df.groupby("timestamp", as_index=True)["portfolio_value"].sum().sort_index()
        try:
            ts_forecast_df = ts_tool.zeroshot_forecast(agg, horizon=96)
        except Exception as e:
            # Surface a readable error in response
            ts_forecast_df = pd.DataFrame({"error": [str(e)]})

    if "explain" in plan["steps"]:
        # BUG FIX: the original `predict_df or sql_df` raises
        # "ValueError: The truth value of a DataFrame is ambiguous" whenever
        # predict_df is a DataFrame — select explicitly on None instead.
        explain_imgs = explain_tool.run(predict_df if predict_df is not None else sql_df)

    report_link = None
    if "report" in plan["steps"]:
        # Add forecast preview if available
        forecast_preview = ts_forecast_df.head(50) if isinstance(ts_forecast_df, pd.DataFrame) else None
        report_link = report_tool.render_and_save(
            user_query=message,
            sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
            predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else forecast_preview,
            explain_images=explain_imgs,
            plan=plan,
        )

    # Record the human review decision alongside the plan and run artifacts.
    tracer.trace_event("hitl", {
        "message": message,
        "decision": hitl_decision,
        "reviewer_note": reviewer_note,
        "artifacts": artifacts,
        "plan": plan,
    })

    # Compose response
    response = f"**Plan:** {plan['steps']}\n**Rationale:** {plan['rationale']}\n"
    if isinstance(sql_df, pd.DataFrame):
        response += f"\n**SQL rows:** {len(sql_df)}"
    if isinstance(predict_df, pd.DataFrame):
        response += f"\n**Predictions rows:** {len(predict_df)}"
    if isinstance(ts_forecast_df, pd.DataFrame) and "forecast" in ts_forecast_df.columns:
        response += f"\n**Forecast horizon:** {len(ts_forecast_df)}"
    if report_link:
        response += f"\n**Report:** {report_link}"
    if tracer.trace_url:
        response += f"\n**Trace:** {tracer.trace_url}"

    # Prefer to show forecast if present, else predictions, else raw query
    preview_df = ts_forecast_df if isinstance(ts_forecast_df, pd.DataFrame) and not ts_forecast_df.empty else \
        (predict_df if isinstance(predict_df, pd.DataFrame) and not predict_df.empty else sql_df)
    return response, (preview_df if isinstance(preview_df, pd.DataFrame) else pd.DataFrame())
# Gradio UI: one text input, a human-review radio + note, and two outputs
# (markdown summary, preview DataFrame) wired to run_agent.
with gr.Blocks() as demo:
    gr.Markdown("# Tabular Agentic XAI (Free-Tier)")
    with gr.Row():
        msg = gr.Textbox(label="Ask your question")
    with gr.Row():
        # Human-in-the-loop controls; recorded via the tracer in run_agent.
        hitl = gr.Radio(["Approve", "Needs Changes"], value="Approve", label="Human Review")
        note = gr.Textbox(label="Reviewer note (optional)")
    out_md = gr.Markdown()
    out_df = gr.Dataframe(interactive=False)
    ask = gr.Button("Run")
    # Outputs must match run_agent's (markdown, DataFrame) return tuple.
    ask.click(run_agent, inputs=[msg, hitl, note], outputs=[out_md, out_df])

if __name__ == "__main__":
    demo.launch()