# ALM_LLM / app.py
# Uploaded by AshenH ("Upload app.py and readme", commit a772359, verified; 5.22 kB)
import os
import json
import gradio as gr
import pandas as pd
from typing import Dict, Any
from tools.sql_tool import SQLTool
from tools.predict_tool import PredictTool
from tools.explain_tool import ExplainTool
from tools.report_tool import ReportTool
from utils.tracing import Tracer
from utils.config import AppConfig
# Optional: tiny orchestration LLM — small enough to run on free-tier CPU.
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    LLM_ID = os.getenv("ORCHESTRATOR_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
    _tok = AutoTokenizer.from_pretrained(LLM_ID)
    _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
    llm = pipeline(
        "text-generation",
        model=_mdl,
        tokenizer=_tok,
        max_new_tokens=512,
    )
except Exception:
    # Any failure here (package missing, no network, out of memory) means we
    # fall back to deterministic tool routing without an LLM.
    llm = None
# Application-wide singletons: config and tracer are built from environment
# variables, and every tool shares both so all tool calls are traced the
# same way. (Tool semantics inferred from usage in run_agent below —
# sql_tool takes the raw user message, predict_tool consumes its output.)
cfg = AppConfig.from_env()
tracer = Tracer.from_env()
sql_tool = SQLTool(cfg, tracer)
predict_tool = PredictTool(cfg, tracer)
explain_tool = ExplainTool(cfg, tracer)
report_tool = ReportTool(cfg, tracer)
# Planning instructions handed to the orchestration LLM (see plan_actions).
# The numbered ordering mirrors the execution order in run_agent.
SYSTEM_PROMPT = (
    "You are an analytical assistant for tabular data. "
    "When the user asks a question, decide which tools to call in order: "
    "1) SQL (if data retrieval is needed) 2) Predict (if scoring is requested) "
    "3) Explain (if attributions or why-questions) 4) Report (if a document is requested). "
    "Always disclose the steps taken and include links to traces if available."
)
def plan_actions(message: str) -> Dict[str, Any]:
    """Decide which tools to run for *message*.

    Uses the orchestration LLM when available; otherwise (or whenever the
    LLM output cannot be parsed into a valid plan) falls back to keyword
    heuristics.

    Returns:
        dict with keys:
            steps: list, subset of ['sql', 'predict', 'explain', 'report']
            rationale: one-sentence explanation of the plan
    """
    valid_steps = {"sql", "predict", "explain", "report"}
    if llm is not None:
        prompt = (
            f"{SYSTEM_PROMPT}\nUser: {message}\n"
            "Return JSON with fields: steps (array, subset of ['sql','predict','explain','report']), "
            "and rationale (one sentence)."
        )
        # The text-generation pipeline echoes the prompt; the model's answer
        # is expected on the final line of the generated text.
        out = llm(prompt)[0]["generated_text"].split("\n")[-1]
        try:
            plan = json.loads(out)
            # Bug fix: previously any successfully-parsed JSON was returned
            # unchecked — a plan without a valid 'steps' list would crash
            # run_agent's `plan["steps"]` lookups. Validate the shape and
            # fall through to the heuristic plan if it is malformed.
            if (
                isinstance(plan, dict)
                and isinstance(plan.get("steps"), list)
                and set(plan["steps"]) <= valid_steps
            ):
                plan.setdefault("rationale", "LLM plan.")
                return plan
        except Exception:
            pass
    # Heuristic fallback: keyword triggers per tool, appended in the same
    # order the tools execute (sql -> predict -> explain -> report).
    steps = []
    m = message.lower()
    if any(k in m for k in ["show", "average", "count", "trend", "top", "sql", "query", "kpi"]):
        steps.append("sql")
    if any(k in m for k in ["predict", "score", "risk", "propensity", "probability"]):
        steps.append("predict")
    if any(k in m for k in ["why", "explain", "shap", "feature", "attribution"]):
        steps.append("explain")
    if any(k in m for k in ["report", "download", "pdf", "summary"]):
        steps.append("report")
    if not steps:
        # Default: treat an unclassified question as a data-retrieval request.
        steps = ["sql"]
    return {"steps": steps, "rationale": "Rule-based plan."}
def run_agent(message: str, hitl_decision: str = "Approve", reviewer_note: str = ""):
    """Plan and execute the tool pipeline for one user message.

    Args:
        message: the user's question.
        hitl_decision: human-in-the-loop verdict ("Approve" / "Needs Changes").
        reviewer_note: optional free-text note from the reviewer.

    Returns:
        (markdown_response, preview_df) where preview_df is the predictions
        DataFrame if one was produced, else the SQL result, else None.
    """
    tracer.trace_event("user_message", {"message": message})
    plan = plan_actions(message)
    tracer.trace_event("plan", plan)

    sql_df = None
    predict_df = None
    explain_plots = {}
    artifacts = {}

    if "sql" in plan["steps"]:
        sql_df = sql_tool.run(message)
        artifacts["sql_rows"] = len(sql_df) if isinstance(sql_df, pd.DataFrame) else 0
    if "predict" in plan["steps"]:
        predict_df = predict_tool.run(sql_df)
    if "explain" in plan["steps"]:
        # Bug fix: `predict_df or sql_df` raises ValueError when predict_df
        # is a DataFrame (its truth value is ambiguous); choose explicitly.
        explain_plots = explain_tool.run(predict_df if predict_df is not None else sql_df)

    report_link = None
    if "report" in plan["steps"]:
        report_link = report_tool.render_and_save(
            user_query=message,
            sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
            predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else None,
            explain_images=explain_plots,
            plan=plan,
        )

    # HITL log (append-only). In production, push to a private HF dataset via API.
    hitl_record = {
        "message": message,
        "decision": hitl_decision,
        "reviewer_note": reviewer_note,
        # tz-aware UTC timestamp; pd.Timestamp.utcnow() is deprecated.
        "timestamp": pd.Timestamp.now(tz="UTC").isoformat(),
        "artifacts": artifacts,
        "plan": plan,
    }
    tracer.trace_event("hitl", hitl_record)

    response = f"**Plan:** {plan['steps']}\n**Rationale:** {plan['rationale']}\n"
    if isinstance(sql_df, pd.DataFrame):
        response += f"\n**SQL rows:** {len(sql_df)}"
    if isinstance(predict_df, pd.DataFrame):
        response += f"\n**Predictions rows:** {len(predict_df)}"
    if report_link:
        response += f"\n**Report:** {report_link}"
    if tracer.trace_url:
        response += f"\n**Trace:** {tracer.trace_url}"

    # Bug fix: same DataFrame-truthiness issue as in the explain step above.
    preview_df = predict_df if predict_df is not None else sql_df
    return response, preview_df
# Gradio UI: a question box, a human-review radio + note, and two outputs
# (markdown summary, DataFrame preview) wired to run_agent.
with gr.Blocks() as demo:
    gr.Markdown("# Tabular Agentic XAI (Free‑Tier)")
    with gr.Row():
        msg = gr.Textbox(label="Ask your question")
    with gr.Row():
        hitl = gr.Radio(["Approve", "Needs Changes"], value="Approve", label="Human Review")
        note = gr.Textbox(label="Reviewer note (optional)")
    out_md = gr.Markdown()
    out_df = gr.Dataframe(interactive=False)
    ask = gr.Button("Run")
    ask.click(run_agent, inputs=[msg, hitl, note], outputs=[out_md, out_df])

if __name__ == "__main__":
    demo.launch()