# ALM_LLM / app.py — "enhanced app.py file" (AshenH, commit da25b2a, 16.6 kB)
# NOTE(review): the lines originally here were Hugging Face web-page chrome
# ("raw", "history blame", file size) accidentally captured along with the
# source; kept as this comment so the file remains valid Python.
# space/app.py
import os
import json
import logging
import gradio as gr
import pandas as pd
from typing import Optional, Tuple
from tools.sql_tool import SQLTool
from tools.predict_tool import PredictTool
from tools.explain_tool import ExplainTool
from tools.report_tool import ReportTool
from tools.ts_preprocess import build_timeseries
from tools.ts_forecast_tool import TimeseriesForecastTool
from utils.tracing import Tracer
from utils.config import AppConfig
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants
MAX_RESPONSE_LENGTH = 10000      # hard cap on user message length, in characters
MAX_FORECAST_HORIZON = 365       # upper bound on forecast periods
DEFAULT_FORECAST_HORIZON = 96    # default number of periods to forecast

# Optional LLM for planning.
# If ORCHESTRATOR_MODEL is set, a text-generation pipeline is loaded and used
# by plan_actions(); any load failure falls back to the heuristic planner.
llm = None
LLM_ID = os.getenv("ORCHESTRATOR_MODEL")
if LLM_ID:
    try:
        # Imported lazily so the app works without transformers installed
        # when no orchestrator model is configured.
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
        logger.info(f"Loading orchestrator model: {LLM_ID}")
        _tok = AutoTokenizer.from_pretrained(LLM_ID)
        _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
        llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)
        logger.info("Orchestrator model loaded successfully")
    except Exception as e:
        logger.warning(f"Failed to load orchestrator model: {e}. Using fallback planner.")
        llm = None

# Initialize configuration and tools.
# Tool construction failure is fatal: the app cannot serve requests without
# them, so the exception is re-raised after logging.
try:
    cfg = AppConfig.from_env()
    tracer = Tracer.from_env()
    sql_tool = SQLTool(cfg, tracer)
    predict_tool = PredictTool(cfg, tracer)
    explain_tool = ExplainTool(cfg, tracer)
    report_tool = ReportTool(cfg, tracer)
    ts_tool = TimeseriesForecastTool(cfg, tracer)
    logger.info("All tools initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize application: {e}")
    raise

# Prompt prefix handed to the optional orchestrator LLM by plan_actions().
SYSTEM_PROMPT = (
    "You are an analytical assistant for tabular data. "
    "Decide which tools to call in order: "
    "1) SQL (retrieve) 2) Predict (score) 3) Explain (SHAP) 4) Report (document) 5) Forecast (Granite TTM). "
    "Always disclose the steps taken."
)
def validate_message(message: str) -> Tuple[bool, str]:
    """Validate a user message before planning and execution.

    Args:
        message: Raw user input.

    Returns:
        Tuple of ``(is_valid, error_message)``; ``error_message`` is empty
        when the message is valid.
    """
    # Guard clauses: empty/blank input, then length cap.
    if not message or not message.strip():
        return False, "Please enter a valid question."
    if len(message) > MAX_RESPONSE_LENGTH:
        return False, f"Message too long. Please limit to {MAX_RESPONSE_LENGTH} characters."

    # Best-effort screen for obviously destructive SQL fragments. This is
    # defence in depth only; real sanitisation belongs to the SQL tool.
    import re
    blocked_patterns = (
        r';\s*drop\s+table',
        r';\s*delete\s+from',
        r';\s*truncate',
        r'union\s+select.*from',
        r'exec\s*\(',
        r'execute\s*\(',
    )
    lowered = message.lower()
    hit = next((p for p in blocked_patterns if re.search(p, lowered)), None)
    if hit is not None:
        logger.warning(f"Suspicious SQL pattern detected: {hit}")
        return False, "Query contains potentially unsafe patterns. Please rephrase."

    return True, ""
def plan_actions(message: str) -> dict:
    """
    Determine which tools to execute based on the user message.

    Tries the optional orchestrator LLM first; on any parse or runtime
    failure falls back to keyword heuristics.

    Returns:
        dict with ``steps`` (ordered subset of
        ['sql', 'predict', 'explain', 'report', 'forecast']) and ``rationale``.
    """
    if llm is not None:
        prompt = (
            f"{SYSTEM_PROMPT}\nUser: {message}\n"
            "Return JSON with fields: steps (array subset of ['sql','predict','explain','report','forecast']), rationale."
        )
        try:
            out = llm(prompt)[0]["generated_text"]
            # The model may echo the prompt: try the last line first, else the
            # last JSON object embedded anywhere in the output.
            last = out.split("\n")[-1].strip()
            if last.startswith("{"):
                obj = json.loads(last)
            else:
                start = out.rfind("{")
                if start < 0:
                    # BUGFIX: rfind() returning -1 previously produced
                    # out[-1:], feeding a single stray character to json.loads.
                    raise json.JSONDecodeError("no JSON object in LLM output", out, 0)
                obj = json.loads(out[start:])
            if isinstance(obj, dict) and "steps" in obj:
                # Keep only recognised step names, preserving order.
                valid_steps = {'sql', 'predict', 'explain', 'report', 'forecast'}
                obj["steps"] = [s for s in obj["steps"] if s in valid_steps]
                if obj["steps"]:
                    logger.info(f"LLM plan: {obj['steps']}")
                    return obj
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse LLM output as JSON: {e}")
        except Exception as e:
            logger.warning(f"LLM planning failed: {e}")

    # Fallback heuristic planning
    m = message.lower()
    steps = []

    def _matches(*keywords: str) -> bool:
        # True if any keyword occurs in the lower-cased message.
        return any(k in m for k in keywords)

    def _ensure_first(step: str) -> None:
        # Prepend a prerequisite step if it is not already planned.
        if step not in steps:
            steps.insert(0, step)

    # SQL keywords
    if _matches("show", "average", "count", "trend", "top", "sql", "query",
                "kpi", "data", "retrieve", "fetch", "list", "view"):
        steps.append("sql")
    # Prediction keywords (predictions need retrieved data first)
    if _matches("predict", "score", "risk", "propensity", "probability",
                "classification", "regression"):
        steps.append("predict")
        _ensure_first("sql")
    # Explanation keywords (explanations need predictions, which need data)
    if _matches("why", "explain", "shap", "feature", "attribution",
                "importance", "interpret"):
        steps.append("explain")
        _ensure_first("predict")
        _ensure_first("sql")
    # Report keywords
    if _matches("report", "download", "pdf", "summary", "document", "export"):
        steps.append("report")
    # Forecast keywords (forecasting needs retrieved data first)
    if _matches("forecast", "next", "horizon", "granite", "predict future",
                "time series", "timeseries"):
        steps.append("forecast")
        _ensure_first("sql")

    # Default to SQL if no steps identified
    if not steps:
        steps = ["sql"]

    rationale = f"Rule-based plan based on keywords: {', '.join(steps)}"
    logger.info(f"Heuristic plan: {steps}")
    return {"steps": steps, "rationale": rationale}
def run_agent(
    message: str,
    hitl_decision: str = "Approve",
    reviewer_note: str = ""
) -> Tuple[str, pd.DataFrame]:
    """
    Main agent execution function.

    Plans the tool pipeline for the user query, then executes each planned
    step in order (SQL -> predict -> time-series build -> forecast ->
    explain -> report). Per-step failures are collected into a warnings list
    instead of aborting the run; a final catch-all guards the whole function.

    Args:
        message: User query
        hitl_decision: Human-in-the-loop decision
        reviewer_note: Optional review notes

    Returns:
        Tuple of (response_text, preview_dataframe)
    """
    try:
        # Validate input before doing any work.
        is_valid, error_msg = validate_message(message)
        if not is_valid:
            logger.warning(f"Invalid message: {error_msg}")
            return f"❌ **Error:** {error_msg}", pd.DataFrame()

        tracer.trace_event("user_message", {"message": message[:500]})  # Limit traced message length

        # Plan actions (LLM or heuristic); planning failure is fatal for this run.
        try:
            plan = plan_actions(message)
            tracer.trace_event("plan", plan)
        except Exception as e:
            logger.error(f"Planning failed: {e}")
            return f"❌ **Planning Error:** Unable to create execution plan. {str(e)}", pd.DataFrame()

        # Initialize result containers.
        sql_df = None            # raw rows retrieved by the SQL tool
        predict_df = None        # rows with model scores attached
        explain_imgs = {}        # SHAP chart images keyed by name
        artifacts = {}           # summary counters shown in the response
        ts_forecast_df = None    # Granite TTM forecast output
        errors = []              # non-fatal issues surfaced to the user

        # Execute SQL step.
        if "sql" in plan["steps"]:
            try:
                sql_df = sql_tool.run(message)
                if isinstance(sql_df, pd.DataFrame):
                    artifacts["sql_rows"] = len(sql_df)
                    logger.info(f"SQL returned {len(sql_df)} rows")
                else:
                    errors.append("SQL query returned no data")
            except Exception as e:
                error_msg = f"SQL execution failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # Execute prediction step (requires SQL data from the step above).
        if "predict" in plan["steps"]:
            try:
                if sql_df is not None and not sql_df.empty:
                    predict_df = predict_tool.run(sql_df)
                    if isinstance(predict_df, pd.DataFrame):
                        artifacts["predict_rows"] = len(predict_df)
                        logger.info(f"Predictions generated for {len(predict_df)} rows")
                else:
                    errors.append("Prediction skipped: no data available")
            except Exception as e:
                error_msg = f"Prediction failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # Build time series if possible (opportunistic, even when no
        # forecast was requested).
        ts_df = None
        if sql_df is not None and not sql_df.empty:
            try:
                ts_df = build_timeseries(sql_df)
                logger.info(f"Time series built with {len(ts_df)} records")
            except Exception as e:
                logger.info(f"Time series preprocessing skipped: {e}")
                # Not always an error - data might not be suitable for TS

        # Execute forecast step.
        if "forecast" in plan["steps"]:
            if ts_df is not None and not ts_df.empty:
                try:
                    # Aggregate portfolio value by timestamp.
                    agg = ts_df.groupby("timestamp", as_index=True)["portfolio_value"].sum().sort_index()
                    if len(agg) < 2:
                        errors.append("Insufficient time series data for forecasting (need at least 2 points)")
                    else:
                        # Validate horizon (clamped to the configured maximum).
                        horizon = min(DEFAULT_FORECAST_HORIZON, MAX_FORECAST_HORIZON)
                        ts_forecast_df = ts_tool.zeroshot_forecast(agg, horizon=horizon)
                        if isinstance(ts_forecast_df, pd.DataFrame):
                            # The tool signals failure via an "error" column
                            # rather than raising.
                            if "error" in ts_forecast_df.columns:
                                errors.append(f"Forecast error: {ts_forecast_df['error'].iloc[0]}")
                                ts_forecast_df = None
                            else:
                                artifacts["forecast_horizon"] = len(ts_forecast_df)
                                logger.info(f"Forecast generated for {len(ts_forecast_df)} periods")
                except Exception as e:
                    error_msg = f"Forecasting failed: {str(e)}"
                    logger.error(error_msg)
                    errors.append(error_msg)
            else:
                errors.append("Forecast skipped: no suitable time series data")

        # Execute explanation step (prefers scored rows; falls back to raw SQL rows).
        if "explain" in plan["steps"]:
            try:
                explain_data = predict_df if predict_df is not None else sql_df
                if explain_data is not None and not explain_data.empty:
                    explain_imgs = explain_tool.run(explain_data)
                    artifacts["explain_charts"] = len(explain_imgs)
                    logger.info(f"Generated {len(explain_imgs)} explanation charts")
                else:
                    errors.append("Explanation skipped: no data available")
            except Exception as e:
                error_msg = f"Explanation failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # Execute report generation (previews capped at 50 rows each).
        report_link = None
        if "report" in plan["steps"]:
            try:
                forecast_preview = ts_forecast_df.head(50) if isinstance(ts_forecast_df, pd.DataFrame) else None
                report_link = report_tool.render_and_save(
                    user_query=message,
                    sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
                    predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else forecast_preview,
                    explain_images=explain_imgs,
                    plan=plan,
                )
                logger.info(f"Report generated: {report_link}")
            except Exception as e:
                error_msg = f"Report generation failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # Log human-in-the-loop decision with run context for auditability.
        tracer.trace_event("hitl", {
            "message": message[:500],
            "decision": hitl_decision,
            "reviewer_note": reviewer_note[:500] if reviewer_note else "",
            "artifacts": artifacts,
            "plan": plan,
            "errors": errors,
        })

        # Compose the markdown response shown to the user.
        response = f"**Plan:** {', '.join(plan['steps'])}\n\n**Rationale:** {plan['rationale']}\n\n"

        # Add artifacts info.
        if artifacts:
            response += "**Results:**\n"
            if "sql_rows" in artifacts:
                response += f"- SQL query returned {artifacts['sql_rows']} rows\n"
            if "predict_rows" in artifacts:
                response += f"- Generated predictions for {artifacts['predict_rows']} rows\n"
            if "forecast_horizon" in artifacts:
                response += f"- Forecast generated for {artifacts['forecast_horizon']} periods\n"
            if "explain_charts" in artifacts:
                response += f"- Created {artifacts['explain_charts']} explanation charts\n"
            response += "\n"

        # Add report link.
        if report_link:
            response += f"📄 **Report:** {report_link}\n\n"

        # Add trace URL.
        if tracer.trace_url:
            response += f"🔍 **Trace:** {tracer.trace_url}\n\n"

        # Add errors if any.
        if errors:
            response += "**⚠️ Warnings/Errors:**\n"
            for err in errors:
                response += f"- {err}\n"

        # Determine preview dataframe: most derived non-empty result wins,
        # capped at 100 rows.
        if isinstance(ts_forecast_df, pd.DataFrame) and not ts_forecast_df.empty:
            preview_df = ts_forecast_df.head(100)
        elif isinstance(predict_df, pd.DataFrame) and not predict_df.empty:
            preview_df = predict_df.head(100)
        elif isinstance(sql_df, pd.DataFrame) and not sql_df.empty:
            preview_df = sql_df.head(100)
        else:
            preview_df = pd.DataFrame({"message": ["No data to display"]})

        return response, preview_df

    except Exception as e:
        # Catch-all boundary: log with traceback, trace, and degrade gracefully.
        error_msg = f"Unexpected error in agent execution: {str(e)}"
        logger.exception(error_msg)
        tracer.trace_event("error", {"message": error_msg})
        return f"❌ **Critical Error:** {error_msg}", pd.DataFrame()
# Gradio Interface: question box + human-review controls feed run_agent();
# results land in a markdown panel and a capped dataframe preview.
with gr.Blocks(title="Tabular Agentic XAI") as demo:
    gr.Markdown("""
    # 🤖 Tabular Agentic XAI (Enterprise Edition)
    An intelligent assistant for analyzing tabular data with ML predictions, explanations, and time-series forecasting.
    **Capabilities:**
    - 📊 SQL queries and data retrieval
    - 🎯 ML predictions with confidence scores
    - 🔍 SHAP-based model explanations
    - 📈 Time-series forecasting with Granite TTM
    - 📄 Automated report generation
    """)
    with gr.Row():
        msg = gr.Textbox(
            label="Ask your question",
            placeholder="e.g., Show me the top 10 customers by revenue, predict churn risk, forecast next quarter...",
            lines=3
        )
    with gr.Row():
        # Human-in-the-loop controls; the decision and note are logged by
        # run_agent() via the tracer, they do not gate execution.
        hitl = gr.Radio(
            ["Approve", "Needs Changes"],
            value="Approve",
            label="Human Review",
            info="Review the planned actions before execution"
        )
        note = gr.Textbox(
            label="Reviewer note (optional)",
            placeholder="Add any review comments...",
            lines=2
        )
    out_md = gr.Markdown(label="Response")
    out_df = gr.Dataframe(
        interactive=False,
        label="Data Preview (max 100 rows)",
        wrap=True
    )
    with gr.Row():
        ask = gr.Button("🚀 Run Analysis", variant="primary")
        clear = gr.Button("🔄 Clear")
    ask.click(
        run_agent,
        inputs=[msg, hitl, note],
        outputs=[out_md, out_df]
    )
    # Reset all inputs and outputs to their defaults.
    clear.click(
        lambda: ("", "Approve", "", "", pd.DataFrame()),
        outputs=[msg, hitl, note, out_md, out_df]
    )
    gr.Markdown("""
    ---
    **Tips:**
    - Be specific in your queries for better results
    - Use natural language - the system will interpret your intent
    - Review the execution plan before approving
    - Check the trace link for detailed execution logs
    """)
if __name__ == "__main__":
    # Launch the Gradio app bound to all interfaces on port 7860 (the
    # standard Hugging Face Spaces port), surfacing errors in the UI.
    logger.info("Starting Gradio application...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )