# space/app.py
import os
import re
import json
import logging
import gradio as gr
import pandas as pd
from typing import Tuple

from tools.sql_tool import SQLTool
from tools.predict_tool import PredictTool
from tools.explain_tool import ExplainTool
from tools.report_tool import ReportTool
from tools.ts_preprocess import build_timeseries
from tools.ts_forecast_tool import TimeseriesForecastTool

from utils.tracing import Tracer
from utils.config import AppConfig
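
# NOTE: runtime configuration comes from the environment. ORCHESTRATOR_MODEL
# (optional) selects the planning LLM; database, model, and tracing settings are
# resolved by AppConfig.from_env() and Tracer.from_env() (see utils/ for the
# exact variable names).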

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants
MAX_RESPONSE_LENGTH = 10000       # maximum accepted user-message length, in characters (see validate_message)
MAX_FORECAST_HORIZON = 365        # hard cap on the forecast horizon, in periods
DEFAULT_FORECAST_HORIZON = 96     # default forecast horizon, in periods

# Optional LLM for planning
llm = None
LLM_ID = os.getenv("ORCHESTRATOR_MODEL")
if LLM_ID:
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
        logger.info(f"Loading orchestrator model: {LLM_ID}")
        _tok = AutoTokenizer.from_pretrained(LLM_ID)
        _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
        llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)
        logger.info("Orchestrator model loaded successfully")
    except Exception as e:
        logger.warning(f"Failed to load orchestrator model: {e}. Using fallback planner.")
        llm = None

# Initialize configuration and tools
try:
    cfg = AppConfig.from_env()
    tracer = Tracer.from_env()
    
    sql_tool = SQLTool(cfg, tracer)
    predict_tool = PredictTool(cfg, tracer)
    explain_tool = ExplainTool(cfg, tracer)
    report_tool = ReportTool(cfg, tracer)
    ts_tool = TimeseriesForecastTool(cfg, tracer)
    
    logger.info("All tools initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize application: {e}")
    raise

SYSTEM_PROMPT = (
    "You are an analytical assistant for tabular data. "
    "Decide which tools to call in order: "
    "1) SQL (retrieve) 2) Predict (score) 3) Explain (SHAP) 4) Report (document) 5) Forecast (Granite TTM). "
    "Always disclose the steps taken."
)
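
# The planner (plan_actions) asks the LLM to return JSON shaped like the
# illustrative example below; unparsable output falls back to keyword heuristics.
#   {"steps": ["sql", "predict", "explain"], "rationale": "why these steps were chosen"}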


def validate_message(message: str) -> Tuple[bool, str]:
    """Validate user input message."""
    if not message or not message.strip():
        return False, "Please enter a valid question."
    
    if len(message) > MAX_RESPONSE_LENGTH:
        return False, f"Message too long. Please limit to {MAX_RESPONSE_LENGTH} characters."
    
    # Basic SQL injection pattern detection
    suspicious_patterns = [
        r';\s*drop\s+table',
        r';\s*delete\s+from',
        r';\s*truncate',
        r'union\s+select.*from',
        r'exec\s*\(',
        r'execute\s*\('
    ]
    
    msg_lower = message.lower()
    for pattern in suspicious_patterns:
        if re.search(pattern, msg_lower):
            logger.warning(f"Suspicious SQL pattern detected: {pattern}")
            return False, "Query contains potentially unsafe patterns. Please rephrase."
    
    return True, ""


def plan_actions(message: str) -> dict:
    """
    Determine which tools to execute based on the user message.
    Uses LLM if available, otherwise falls back to heuristics.
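
    Example (illustrative, assuming the heuristic fallback is used):
        plan_actions("Why is churn high? Explain the drivers.")
        -> {"steps": ["sql", "predict", "explain"],
            "rationale": "Rule-based plan based on keywords: sql, predict, explain"}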
    """
    if llm is not None:
        prompt = (
            f"{SYSTEM_PROMPT}\nUser: {message}\n"
            "Return JSON with fields: steps (array subset of ['sql','predict','explain','report','forecast']), rationale."
        )
        try:
            out = llm(prompt)[0]["generated_text"]
            # Best-effort JSON extraction: use the last line if it looks like an
            # object, otherwise everything from the last '{' onward.
            last = out.split("\n")[-1].strip()
            obj = json.loads(last) if last.startswith("{") else json.loads(out[out.rfind("{"):])
            
            if isinstance(obj, dict) and "steps" in obj:
                # Validate steps
                valid_steps = {'sql', 'predict', 'explain', 'report', 'forecast'}
                obj["steps"] = [s for s in obj["steps"] if s in valid_steps]
                if obj["steps"]:
                    logger.info(f"LLM plan: {obj['steps']}")
                    return obj
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse LLM output as JSON: {e}")
        except Exception as e:
            logger.warning(f"LLM planning failed: {e}")
    
    # Fallback heuristic planning
    m = message.lower()
    steps = []
    
    # SQL keywords
    if any(k in m for k in ["show", "average", "count", "trend", "top", "sql", "query", 
                             "kpi", "data", "retrieve", "fetch", "list", "view"]):
        steps.append("sql")
    
    # Prediction keywords
    if any(k in m for k in ["predict", "score", "risk", "propensity", "probability", 
                             "classification", "regression"]):
        steps.append("predict")
        if "sql" not in steps:
            steps.insert(0, "sql")  # Need data first
    
    # Explanation keywords
    if any(k in m for k in ["why", "explain", "shap", "feature", "attribution", 
                             "importance", "interpret"]):
        steps.append("explain")
        if "predict" not in steps:
            steps.insert(0, "predict")
        if "sql" not in steps:
            steps.insert(0, "sql")
    
    # Report keywords
    if any(k in m for k in ["report", "download", "pdf", "summary", "document", "export"]):
        steps.append("report")
    
    # Forecast keywords
    if any(k in m for k in ["forecast", "next", "horizon", "granite", "predict future", 
                             "time series", "timeseries"]):
        steps.append("forecast")
        if "sql" not in steps:
            steps.insert(0, "sql")
    
    # Default to SQL if no steps identified
    if not steps:
        steps = ["sql"]
    
    rationale = f"Rule-based plan based on keywords: {', '.join(steps)}"
    logger.info(f"Heuristic plan: {steps}")
    return {"steps": steps, "rationale": rationale}


def run_agent(
    message: str, 
    hitl_decision: str = "Approve", 
    reviewer_note: str = ""
) -> Tuple[str, pd.DataFrame]:
    """
    Main agent execution function.
    
    Args:
        message: User query
        hitl_decision: Human-in-the-loop decision
        reviewer_note: Optional review notes
    
    Returns:
        Tuple of (response_text, preview_dataframe)
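
    Example (illustrative; assumes the tools and data source are configured):
        text, preview = run_agent("Show the top 10 customers by revenue")
        # `text` is a Markdown summary of the plan and results;
        # `preview` is a DataFrame capped at 100 rows.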
    """
    try:
        # Validate input
        is_valid, error_msg = validate_message(message)
        if not is_valid:
            logger.warning(f"Invalid message: {error_msg}")
            return f"❌ **Error:** {error_msg}", pd.DataFrame()
        
        tracer.trace_event("user_message", {"message": message[:500]})  # Limit traced message length
        
        # Plan actions
        try:
            plan = plan_actions(message)
            tracer.trace_event("plan", plan)
        except Exception as e:
            logger.error(f"Planning failed: {e}")
            return f"❌ **Planning Error:** Unable to create execution plan. {str(e)}", pd.DataFrame()
        
        # Initialize result containers
        sql_df = None
        predict_df = None
        explain_imgs = {}
        artifacts = {}
        ts_forecast_df = None
        errors = []
        
        # Execute SQL step
        if "sql" in plan["steps"]:
            try:
                sql_df = sql_tool.run(message)
                if isinstance(sql_df, pd.DataFrame):
                    artifacts["sql_rows"] = len(sql_df)
                    logger.info(f"SQL returned {len(sql_df)} rows")
                else:
                    errors.append("SQL query returned no data")
            except Exception as e:
                error_msg = f"SQL execution failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)
        
        # Execute prediction step
        if "predict" in plan["steps"]:
            try:
                if sql_df is not None and not sql_df.empty:
                    predict_df = predict_tool.run(sql_df)
                    if isinstance(predict_df, pd.DataFrame):
                        artifacts["predict_rows"] = len(predict_df)
                        logger.info(f"Predictions generated for {len(predict_df)} rows")
                else:
                    errors.append("Prediction skipped: no data available")
            except Exception as e:
                error_msg = f"Prediction failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)
        
        # Build time series if possible
        ts_df = None
        if sql_df is not None and not sql_df.empty:
            try:
                ts_df = build_timeseries(sql_df)
                logger.info(f"Time series built with {len(ts_df)} records")
            except Exception as e:
                logger.info(f"Time series preprocessing skipped: {e}")
                # Not always an error - data might not be suitable for TS
        
        # Execute forecast step
        if "forecast" in plan["steps"]:
            if ts_df is not None and not ts_df.empty:
                try:
                    # Aggregate portfolio value by timestamp
                    agg = ts_df.groupby("timestamp", as_index=True)["portfolio_value"].sum().sort_index()
                    
                    if len(agg) < 2:
                        errors.append("Insufficient time series data for forecasting (need at least 2 points)")
                    else:
                        # Clamp the horizon; with the current constants this always
                        # resolves to DEFAULT_FORECAST_HORIZON (no user-supplied horizon is parsed yet).
                        horizon = min(DEFAULT_FORECAST_HORIZON, MAX_FORECAST_HORIZON)
                        ts_forecast_df = ts_tool.zeroshot_forecast(agg, horizon=horizon)
                        
                        if isinstance(ts_forecast_df, pd.DataFrame):
                            if "error" in ts_forecast_df.columns:
                                errors.append(f"Forecast error: {ts_forecast_df['error'].iloc[0]}")
                                ts_forecast_df = None
                            else:
                                artifacts["forecast_horizon"] = len(ts_forecast_df)
                                logger.info(f"Forecast generated for {len(ts_forecast_df)} periods")
                except Exception as e:
                    error_msg = f"Forecasting failed: {str(e)}"
                    logger.error(error_msg)
                    errors.append(error_msg)
            else:
                errors.append("Forecast skipped: no suitable time series data")
        
        # Execute explanation step
        if "explain" in plan["steps"]:
            try:
                explain_data = predict_df if predict_df is not None else sql_df
                if explain_data is not None and not explain_data.empty:
                    explain_imgs = explain_tool.run(explain_data)
                    artifacts["explain_charts"] = len(explain_imgs)
                    logger.info(f"Generated {len(explain_imgs)} explanation charts")
                else:
                    errors.append("Explanation skipped: no data available")
            except Exception as e:
                error_msg = f"Explanation failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)
        
        # Execute report generation
        report_link = None
        if "report" in plan["steps"]:
            try:
                forecast_preview = ts_forecast_df.head(50) if isinstance(ts_forecast_df, pd.DataFrame) else None
                report_link = report_tool.render_and_save(
                    user_query=message,
                    sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
                    predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else forecast_preview,
                    explain_images=explain_imgs,
                    plan=plan,
                )
                logger.info(f"Report generated: {report_link}")
            except Exception as e:
                error_msg = f"Report generation failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)
        
        # Log human-in-the-loop decision
        tracer.trace_event("hitl", {
            "message": message[:500],
            "decision": hitl_decision,
            "reviewer_note": reviewer_note[:500] if reviewer_note else "",
            "artifacts": artifacts,
            "plan": plan,
            "errors": errors,
        })
        
        # Compose response
        response = f"**Plan:** {', '.join(plan['steps'])}\n\n**Rationale:** {plan['rationale']}\n\n"
        
        # Add artifacts info
        if artifacts:
            response += "**Results:**\n"
            if "sql_rows" in artifacts:
                response += f"- SQL query returned {artifacts['sql_rows']} rows\n"
            if "predict_rows" in artifacts:
                response += f"- Generated predictions for {artifacts['predict_rows']} rows\n"
            if "forecast_horizon" in artifacts:
                response += f"- Forecast generated for {artifacts['forecast_horizon']} periods\n"
            if "explain_charts" in artifacts:
                response += f"- Created {artifacts['explain_charts']} explanation charts\n"
            response += "\n"
        
        # Add report link
        if report_link:
            response += f"📄 **Report:** {report_link}\n\n"
        
        # Add trace URL
        if tracer.trace_url:
            response += f"🔍 **Trace:** {tracer.trace_url}\n\n"
        
        # Add errors if any
        if errors:
            response += "**⚠️ Warnings/Errors:**\n"
            for err in errors:
                response += f"- {err}\n"
        
        # Determine preview dataframe: prefer forecast, then predictions, then raw SQL results
        if isinstance(ts_forecast_df, pd.DataFrame) and not ts_forecast_df.empty:
            preview_df = ts_forecast_df.head(100)
        elif isinstance(predict_df, pd.DataFrame) and not predict_df.empty:
            preview_df = predict_df.head(100)
        elif isinstance(sql_df, pd.DataFrame) and not sql_df.empty:
            preview_df = sql_df.head(100)
        else:
            preview_df = pd.DataFrame({"message": ["No data to display"]})
        
        return response, preview_df
        
    except Exception as e:
        error_msg = f"Unexpected error in agent execution: {str(e)}"
        logger.exception(error_msg)
        tracer.trace_event("error", {"message": error_msg})
        return f"❌ **Critical Error:** {error_msg}", pd.DataFrame()


# Gradio Interface
with gr.Blocks(title="Tabular Agentic XAI") as demo:
    gr.Markdown("""
    # 🤖 Tabular Agentic XAI (Enterprise Edition)

    An intelligent assistant for analyzing tabular data with ML predictions, explanations, and time-series forecasting.

    **Capabilities:**
    - 📊 SQL queries and data retrieval
    - 🎯 ML predictions with confidence scores
    - 🔍 SHAP-based model explanations
    - 📈 Time-series forecasting with Granite TTM
    - 📄 Automated report generation
    """)
    
    with gr.Row():
        msg = gr.Textbox(
            label="Ask your question",
            placeholder="e.g., Show me the top 10 customers by revenue, predict churn risk, forecast next quarter...",
            lines=3
        )
    
    with gr.Row():
        hitl = gr.Radio(
            ["Approve", "Needs Changes"], 
            value="Approve", 
            label="Human Review",
            info="Review the planned actions before execution"
        )
        note = gr.Textbox(
            label="Reviewer note (optional)",
            placeholder="Add any review comments...",
            lines=2
        )
    
    out_md = gr.Markdown(label="Response")
    out_df = gr.Dataframe(
        interactive=False,
        label="Data Preview (max 100 rows)",
        wrap=True
    )
    
    with gr.Row():
        ask = gr.Button("🚀 Run Analysis", variant="primary")
        clear = gr.Button("🔄 Clear")
    
    ask.click(
        run_agent, 
        inputs=[msg, hitl, note], 
        outputs=[out_md, out_df]
    )
    
    clear.click(
        lambda: ("", "Approve", "", "", pd.DataFrame()),
        outputs=[msg, hitl, note, out_md, out_df]
    )
    
    gr.Markdown("""
    ---
    **Tips:**
    - Be specific in your queries for better results
    - Use natural language - the system will interpret your intent
    - Review the execution plan before approving
    - Check the trace link for detailed execution logs
    """)


if __name__ == "__main__":
    logger.info("Starting Gradio application...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )