AshenH committed
Commit af53f4b · verified · 1 Parent(s): a2fcbc9

Update app.py
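Make the planner LLM opt-in: the orchestrator model now loads only when ORCHESTRATOR_MODEL is set, instead of defaulting to Qwen/Qwen2.5-0.5B-Instruct. Harden JSON extraction from the LLM output, trim SYSTEM_PROMPT, drop the typing import and the HITL timestamp field, and make run_agent choose the preview with explicit DataFrame checks so the Gradio table always receives a DataFrame.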

Files changed (1)
  1. app.py +124 -138
app.py CHANGED
@@ -1,138 +1,124 @@
- import os
- import json
- import gradio as gr
- import pandas as pd
- from typing import Dict, Any
-
- from tools.sql_tool import SQLTool
- from tools.predict_tool import PredictTool
- from tools.explain_tool import ExplainTool
- from tools.report_tool import ReportTool
- from utils.tracing import Tracer
- from utils.config import AppConfig
-
- # Optional: tiny orchestration LLM (keep it simple on CPU)
- try:
-     from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-     LLM_ID = os.getenv("ORCHESTRATOR_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
-     _tok = AutoTokenizer.from_pretrained(LLM_ID)
-     _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
-     llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)
- except Exception:
-     llm = None # Fallback: deterministic tool routing without LLM
-
- cfg = AppConfig.from_env()
- tracer = Tracer.from_env()
-
- sql_tool = SQLTool(cfg, tracer)
- predict_tool = PredictTool(cfg, tracer)
- explain_tool = ExplainTool(cfg, tracer)
- report_tool = ReportTool(cfg, tracer)
-
- SYSTEM_PROMPT = (
-     "You are an analytical assistant for tabular data. "
-     "When the user asks a question, decide which tools to call in order: "
-     "1) SQL (if data retrieval is needed) 2) Predict (if scoring is requested) "
-     "3) Explain (if attributions or why-questions) 4) Report (if a document is requested). "
-     "Always disclose the steps taken and include links to traces if available."
- )
-
-
- def plan_actions(message: str) -> Dict[str, Any]:
-     """Very lightweight planner. Uses LLM if available, else rule-based heuristics."""
-     if llm is not None:
-         prompt = (
-             f"{SYSTEM_PROMPT}\nUser: {message}\n"
-             "Return JSON with fields: steps (array, subset of ['sql','predict','explain','report']), "
-             "and rationale (one sentence)."
-         )
-         out = llm(prompt)[0]["generated_text"].split("\n")[-1]
-         try:
-             plan = json.loads(out)
-             return plan
-         except Exception:
-             pass
-     # Heuristic fallback
-     steps = []
-     m = message.lower()
-     if any(k in m for k in ["show", "average", "count", "trend", "top", "sql", "query", "kpi"]):
-         steps.append("sql")
-     if any(k in m for k in ["predict", "score", "risk", "propensity", "probability"]):
-         steps.append("predict")
-     if any(k in m for k in ["why", "explain", "shap", "feature", "attribution"]):
-         steps.append("explain")
-     if any(k in m for k in ["report", "download", "pdf", "summary"]):
-         steps.append("report")
-     if not steps:
-         steps = ["sql"]
-     return {"steps": steps, "rationale": "Rule-based plan."}
-
-
- def run_agent(message: str, hitl_decision: str = "Approve", reviewer_note: str = ""):
-     tracer.trace_event("user_message", {"message": message})
-     plan = plan_actions(message)
-     tracer.trace_event("plan", plan)
-
-     sql_df = None
-     predict_df = None
-     explain_plots = {}
-     artifacts = {}
-
-     if "sql" in plan["steps"]:
-         sql_df = sql_tool.run(message)
-         artifacts["sql_rows"] = len(sql_df) if isinstance(sql_df, pd.DataFrame) else 0
-
-     if "predict" in plan["steps"]:
-         predict_df = predict_tool.run(sql_df)
-
-     if "explain" in plan["steps"]:
-         explain_plots = explain_tool.run(predict_df or sql_df)
-
-     report_link = None
-     if "report" in plan["steps"]:
-         report_link = report_tool.render_and_save(
-             user_query=message,
-             sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
-             predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else None,
-             explain_images=explain_plots,
-             plan=plan,
-         )
-
-     # HITL log (append-only). In production, push to a private HF dataset via API.
-     hitl_record = {
-         "message": message,
-         "decision": hitl_decision,
-         "reviewer_note": reviewer_note,
-         "timestamp": pd.Timestamp.utcnow().isoformat(),
-         "artifacts": artifacts,
-         "plan": plan,
-     }
-     tracer.trace_event("hitl", hitl_record)
-
-     response = f"**Plan:** {plan['steps']}\n**Rationale:** {plan['rationale']}\n"
-     if isinstance(sql_df, pd.DataFrame):
-         response += f"\n**SQL rows:** {len(sql_df)}"
-     if isinstance(predict_df, pd.DataFrame):
-         response += f"\n**Predictions rows:** {len(predict_df)}"
-     if report_link:
-         response += f"\n**Report:** {report_link}"
-     if tracer.trace_url:
-         response += f"\n**Trace:** {tracer.trace_url}"
-
-     preview_df = predict_df or sql_df
-     return response, preview_df
-
- with gr.Blocks() as demo:
-     gr.Markdown("# Tabular Agentic XAI (Free‑Tier)")
-     with gr.Row():
-         msg = gr.Textbox(label="Ask your question")
-     with gr.Row():
-         hitl = gr.Radio(["Approve", "Needs Changes"], value="Approve", label="Human Review")
-         note = gr.Textbox(label="Reviewer note (optional)")
-     out_md = gr.Markdown()
-     out_df = gr.Dataframe(interactive=False)
-     ask = gr.Button("Run")
-     ask.click(run_agent, inputs=[msg, hitl, note], outputs=[out_md, out_df])
-
- if __name__ == "__main__":
-     demo.launch()
 
+ import os
+ import json
+ import gradio as gr
+ import pandas as pd
+
+ from tools.sql_tool import SQLTool
+ from tools.predict_tool import PredictTool
+ from tools.explain_tool import ExplainTool
+ from tools.report_tool import ReportTool
+ from utils.tracing import Tracer
+ from utils.config import AppConfig
+
+ # Optional tiny CPU LLM for planning (can be disabled by not setting ORCHESTRATOR_MODEL)
+ llm = None
+ LLM_ID = os.getenv("ORCHESTRATOR_MODEL")
+ if LLM_ID:
+     try:
+         from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+         _tok = AutoTokenizer.from_pretrained(LLM_ID)
+         _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
+         llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)
+     except Exception:
+         llm = None
+
+ cfg = AppConfig.from_env()
+ tracer = Tracer.from_env()
+
+ sql_tool = SQLTool(cfg, tracer)
+ predict_tool = PredictTool(cfg, tracer)
+ explain_tool = ExplainTool(cfg, tracer)
+ report_tool = ReportTool(cfg, tracer)
+
+ SYSTEM_PROMPT = (
+     "You are an analytical assistant for tabular data. "
+     "Decide which tools to call in order: "
+     "1) SQL (retrieve) 2) Predict (score) 3) Explain (SHAP) 4) Report (document). "
+     "Always disclose the steps taken."
+ )
+
+ def plan_actions(message: str):
+     if llm is not None:
+         prompt = (
+             f"{SYSTEM_PROMPT}\nUser: {message}\n"
+             "Return JSON with fields: steps (array subset of ['sql','predict','explain','report']), rationale."
+         )
+         try:
+             out = llm(prompt)[0]["generated_text"]
+             last = out.split("\n")[-1].strip()
+             obj = json.loads(last) if last.startswith("{") else json.loads(out[out.rfind("{"):])
+             if isinstance(obj, dict) and "steps" in obj:
+                 return obj
+         except Exception:
+             pass
+     # Fallback heuristic:
+     m = message.lower()
+     steps = []
+     if any(k in m for k in ["show", "average", "count", "trend", "top", "sql", "query", "kpi"]): steps.append("sql")
+     if any(k in m for k in ["predict", "score", "risk", "propensity", "probability"]): steps.append("predict")
+     if any(k in m for k in ["why", "explain", "shap", "feature", "attribution"]): steps.append("explain")
+     if any(k in m for k in ["report", "download", "pdf", "summary"]): steps.append("report")
+     if not steps: steps = ["sql"]
+     return {"steps": steps, "rationale": "Rule-based plan."}
+
+ def run_agent(message: str, hitl_decision: str = "Approve", reviewer_note: str = ""):
+     tracer.trace_event("user_message", {"message": message})
+     plan = plan_actions(message)
+     tracer.trace_event("plan", plan)
+
+     sql_df = None
+     predict_df = None
+     explain_imgs = {}
+     artifacts = {}
+
+     if "sql" in plan["steps"]:
+         sql_df = sql_tool.run(message)
+         artifacts["sql_rows"] = int(len(sql_df)) if isinstance(sql_df, pd.DataFrame) else 0
+
+     if "predict" in plan["steps"]:
+         predict_df = predict_tool.run(sql_df)
+
+     if "explain" in plan["steps"]:
+         explain_imgs = explain_tool.run(predict_df or sql_df)
+
+     report_link = None
+     if "report" in plan["steps"]:
+         report_link = report_tool.render_and_save(
+             user_query=message,
+             sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
+             predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else None,
+             explain_images=explain_imgs,
+             plan=plan,
+         )
+
+     tracer.trace_event("hitl", {
+         "message": message,
+         "decision": hitl_decision,
+         "reviewer_note": reviewer_note,
+         "artifacts": artifacts,
+         "plan": plan,
+     })
+
+     response = f"**Plan:** {plan['steps']}\n**Rationale:** {plan['rationale']}\n"
+     if isinstance(sql_df, pd.DataFrame): response += f"\n**SQL rows:** {len(sql_df)}"
+     if isinstance(predict_df, pd.DataFrame): response += f"\n**Predictions rows:** {len(predict_df)}"
+     if report_link: response += f"\n**Report:** {report_link}"
+     if tracer.trace_url: response += f"\n**Trace:** {tracer.trace_url}"
+
+     preview_df = predict_df if isinstance(predict_df, pd.DataFrame) and len(predict_df) else sql_df
+     return response, (preview_df if isinstance(preview_df, pd.DataFrame) else pd.DataFrame())
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Tabular Agentic XAI (Free-Tier)")
+     with gr.Row():
+         msg = gr.Textbox(label="Ask your question")
+     with gr.Row():
+         hitl = gr.Radio(["Approve", "Needs Changes"], value="Approve", label="Human Review")
+         note = gr.Textbox(label="Reviewer note (optional)")
+     out_md = gr.Markdown()
+     out_df = gr.Dataframe(interactive=False)
+     ask = gr.Button("Run")
+     ask.click(run_agent, inputs=[msg, hitl, note], outputs=[out_md, out_df])
+
+ if __name__ == "__main__":
+     demo.launch()
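
For a quick local check of the new routing behavior, the rule-based branch of plan_actions can be exercised on its own. The sketch below copies the keyword lists from the updated file and reproduces the fallback routing without importing the tools or loading an LLM; the KEYWORDS table and the route helper are illustrative names, not part of the repository.

# Standalone sketch of the rule-based fallback in plan_actions (assumed names: KEYWORDS, route).
KEYWORDS = {
    "sql": ["show", "average", "count", "trend", "top", "sql", "query", "kpi"],
    "predict": ["predict", "score", "risk", "propensity", "probability"],
    "explain": ["why", "explain", "shap", "feature", "attribution"],
    "report": ["report", "download", "pdf", "summary"],
}

def route(message: str) -> list:
    # Mirror the fallback: keep tool order sql -> predict -> explain -> report, default to ["sql"].
    m = message.lower()
    steps = [tool for tool, words in KEYWORDS.items() if any(w in m for w in words)]
    return steps or ["sql"]

if __name__ == "__main__":
    print(route("Show the average churn risk and explain why"))  # ['sql', 'predict', 'explain']
    print(route("Generate a pdf summary"))                       # ['report']
    print(route("Hello"))                                        # ['sql']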