AshenH committed on
Commit
e81b80e
·
verified ·
1 Parent(s): 980e633

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -423
app.py CHANGED
@@ -1,443 +1,84 @@
1
- # space/app.py
2
  import os
3
- import json
4
- import logging
5
- import gradio as gr
6
  import pandas as pd
7
- from typing import Optional, Tuple
8
 
9
  from tools.sql_tool import SQLTool
10
- from tools.predict_tool import PredictTool
11
- from tools.explain_tool import ExplainTool
12
- from tools.report_tool import ReportTool
13
  from tools.ts_preprocess import build_timeseries
14
- from tools.ts_forecast_tool import TimeseriesForecastTool
15
 
16
- from utils.tracing import Tracer
17
- from utils.config import AppConfig
 
 
18
 
19
- # Configure logging
20
- logging.basicConfig(
21
- level=logging.INFO,
22
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
23
- )
24
- logger = logging.getLogger(__name__)
25
 
26
- # Constants
27
- MAX_RESPONSE_LENGTH = 10000
28
- MAX_FORECAST_HORIZON = 365
29
- DEFAULT_FORECAST_HORIZON = 96
30
 
31
- # Optional LLM for planning
32
- llm = None
33
- LLM_ID = os.getenv("ORCHESTRATOR_MODEL")
34
- if LLM_ID:
35
- try:
36
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
37
- logger.info(f"Loading orchestrator model: {LLM_ID}")
38
- _tok = AutoTokenizer.from_pretrained(LLM_ID)
39
- _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
40
- llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)
41
- logger.info("Orchestrator model loaded successfully")
42
- except Exception as e:
43
- logger.warning(f"Failed to load orchestrator model: {e}. Using fallback planner.")
44
- llm = None
45
 
46
- # Initialize configuration and tools
47
- try:
48
- cfg = AppConfig.from_env()
49
- tracer = Tracer.from_env()
50
-
51
- sql_tool = SQLTool(cfg, tracer)
52
- predict_tool = PredictTool(cfg, tracer)
53
- explain_tool = ExplainTool(cfg, tracer)
54
- report_tool = ReportTool(cfg, tracer)
55
- ts_tool = TimeseriesForecastTool(cfg, tracer)
56
-
57
- logger.info("All tools initialized successfully")
58
- except Exception as e:
59
- logger.error(f"Failed to initialize application: {e}")
60
- raise
61
 
62
- SYSTEM_PROMPT = (
63
- "You are an analytical assistant for tabular data. "
64
- "Decide which tools to call in order: "
65
- "1) SQL (retrieve) 2) Predict (score) 3) Explain (SHAP) 4) Report (document) 5) Forecast (Granite TTM). "
66
- "Always disclose the steps taken."
67
- )
68
-
69
-
70
- def validate_message(message: str) -> Tuple[bool, str]:
71
- """Validate user input message."""
72
- if not message or not message.strip():
73
- return False, "Please enter a valid question."
74
-
75
- if len(message) > MAX_RESPONSE_LENGTH:
76
- return False, f"Message too long. Please limit to {MAX_RESPONSE_LENGTH} characters."
77
-
78
- # Basic SQL injection pattern detection
79
- suspicious_patterns = [
80
- r';\s*drop\s+table',
81
- r';\s*delete\s+from',
82
- r';\s*truncate',
83
- r'union\s+select.*from',
84
- r'exec\s*\(',
85
- r'execute\s*\('
86
- ]
87
-
88
- import re
89
- msg_lower = message.lower()
90
- for pattern in suspicious_patterns:
91
- if re.search(pattern, msg_lower):
92
- logger.warning(f"Suspicious SQL pattern detected: {pattern}")
93
- return False, "Query contains potentially unsafe patterns. Please rephrase."
94
-
95
- return True, ""
96
 
 
 
 
 
 
97
 
98
- def plan_actions(message: str) -> dict:
99
- """
100
- Determine which tools to execute based on the user message.
101
- Uses LLM if available, otherwise falls back to heuristics.
102
- """
103
- if llm is not None:
104
- prompt = (
105
- f"{SYSTEM_PROMPT}\nUser: {message}\n"
106
- "Return JSON with fields: steps (array subset of ['sql','predict','explain','report','forecast']), rationale."
107
- )
108
- try:
109
- out = llm(prompt)[0]["generated_text"]
110
- last = out.split("\n")[-1].strip()
111
- obj = json.loads(last) if last.startswith("{") else json.loads(out[out.rfind("{"):])
112
-
113
- if isinstance(obj, dict) and "steps" in obj:
114
- # Validate steps
115
- valid_steps = {'sql', 'predict', 'explain', 'report', 'forecast'}
116
- obj["steps"] = [s for s in obj["steps"] if s in valid_steps]
117
- if obj["steps"]:
118
- logger.info(f"LLM plan: {obj['steps']}")
119
- return obj
120
- except json.JSONDecodeError as e:
121
- logger.warning(f"Failed to parse LLM output as JSON: {e}")
122
- except Exception as e:
123
- logger.warning(f"LLM planning failed: {e}")
124
-
125
- # Fallback heuristic planning
126
- m = message.lower()
127
- steps = []
128
-
129
- # SQL keywords
130
- if any(k in m for k in ["show", "average", "count", "trend", "top", "sql", "query",
131
- "kpi", "data", "retrieve", "fetch", "list", "view"]):
132
- steps.append("sql")
133
-
134
- # Prediction keywords
135
- if any(k in m for k in ["predict", "score", "risk", "propensity", "probability",
136
- "classification", "regression"]):
137
- steps.append("predict")
138
- if "sql" not in steps:
139
- steps.insert(0, "sql") # Need data first
140
-
141
- # Explanation keywords
142
- if any(k in m for k in ["why", "explain", "shap", "feature", "attribution",
143
- "importance", "interpret"]):
144
- steps.append("explain")
145
- if "predict" not in steps:
146
- steps.insert(0, "predict")
147
- if "sql" not in steps:
148
- steps.insert(0, "sql")
149
-
150
- # Report keywords
151
- if any(k in m for k in ["report", "download", "pdf", "summary", "document", "export"]):
152
- steps.append("report")
153
-
154
- # Forecast keywords
155
- if any(k in m for k in ["forecast", "next", "horizon", "granite", "predict future",
156
- "time series", "timeseries"]):
157
- steps.append("forecast")
158
- if "sql" not in steps:
159
- steps.insert(0, "sql")
160
-
161
- # Default to SQL if no steps identified
162
- if not steps:
163
- steps = ["sql"]
164
-
165
- rationale = f"Rule-based plan based on keywords: {', '.join(steps)}"
166
- logger.info(f"Heuristic plan: {steps}")
167
- return {"steps": steps, "rationale": rationale}
168
 
 
 
 
169
 
170
- def run_agent(
171
- message: str,
172
- hitl_decision: str = "Approve",
173
- reviewer_note: str = ""
174
- ) -> Tuple[str, pd.DataFrame]:
175
- """
176
- Main agent execution function.
177
-
178
- Args:
179
- message: User query
180
- hitl_decision: Human-in-the-loop decision
181
- reviewer_note: Optional review notes
182
-
183
- Returns:
184
- Tuple of (response_text, preview_dataframe)
185
- """
186
  try:
187
- # Validate input
188
- is_valid, error_msg = validate_message(message)
189
- if not is_valid:
190
- logger.warning(f"Invalid message: {error_msg}")
191
- return f"❌ **Error:** {error_msg}", pd.DataFrame()
192
-
193
- tracer.trace_event("user_message", {"message": message[:500]}) # Limit traced message length
194
-
195
- # Plan actions
196
- try:
197
- plan = plan_actions(message)
198
- tracer.trace_event("plan", plan)
199
- except Exception as e:
200
- logger.error(f"Planning failed: {e}")
201
- return f"❌ **Planning Error:** Unable to create execution plan. {str(e)}", pd.DataFrame()
202
-
203
- # Initialize result containers
204
- sql_df = None
205
- predict_df = None
206
- explain_imgs = {}
207
- artifacts = {}
208
- ts_forecast_df = None
209
- errors = []
210
-
211
- # Execute SQL step
212
- if "sql" in plan["steps"]:
213
- try:
214
- sql_df = sql_tool.run(message)
215
- if isinstance(sql_df, pd.DataFrame):
216
- artifacts["sql_rows"] = len(sql_df)
217
- logger.info(f"SQL returned {len(sql_df)} rows")
218
- else:
219
- errors.append("SQL query returned no data")
220
- except Exception as e:
221
- error_msg = f"SQL execution failed: {str(e)}"
222
- logger.error(error_msg)
223
- errors.append(error_msg)
224
-
225
- # Execute prediction step
226
- if "predict" in plan["steps"]:
227
- try:
228
- if sql_df is not None and not sql_df.empty:
229
- predict_df = predict_tool.run(sql_df)
230
- if isinstance(predict_df, pd.DataFrame):
231
- artifacts["predict_rows"] = len(predict_df)
232
- logger.info(f"Predictions generated for {len(predict_df)} rows")
233
- else:
234
- errors.append("Prediction skipped: no data available")
235
- except Exception as e:
236
- error_msg = f"Prediction failed: {str(e)}"
237
- logger.error(error_msg)
238
- errors.append(error_msg)
239
-
240
- # Build time series if possible
241
- ts_df = None
242
- if sql_df is not None and not sql_df.empty:
243
- try:
244
- ts_df = build_timeseries(sql_df)
245
- logger.info(f"Time series built with {len(ts_df)} records")
246
- except Exception as e:
247
- logger.info(f"Time series preprocessing skipped: {e}")
248
- # Not always an error - data might not be suitable for TS
249
-
250
- # Execute forecast step
251
- if "forecast" in plan["steps"]:
252
- if ts_df is not None and not ts_df.empty:
253
- try:
254
- # Aggregate portfolio value by timestamp
255
- agg = ts_df.groupby("timestamp", as_index=True)["portfolio_value"].sum().sort_index()
256
-
257
- if len(agg) < 2:
258
- errors.append("Insufficient time series data for forecasting (need at least 2 points)")
259
- else:
260
- # Validate horizon
261
- horizon = min(DEFAULT_FORECAST_HORIZON, MAX_FORECAST_HORIZON)
262
- ts_forecast_df = ts_tool.zeroshot_forecast(agg, horizon=horizon)
263
-
264
- if isinstance(ts_forecast_df, pd.DataFrame):
265
- if "error" in ts_forecast_df.columns:
266
- errors.append(f"Forecast error: {ts_forecast_df['error'].iloc[0]}")
267
- ts_forecast_df = None
268
- else:
269
- artifacts["forecast_horizon"] = len(ts_forecast_df)
270
- logger.info(f"Forecast generated for {len(ts_forecast_df)} periods")
271
- except Exception as e:
272
- error_msg = f"Forecasting failed: {str(e)}"
273
- logger.error(error_msg)
274
- errors.append(error_msg)
275
- else:
276
- errors.append("Forecast skipped: no suitable time series data")
277
-
278
- # Execute explanation step
279
- if "explain" in plan["steps"]:
280
- try:
281
- explain_data = predict_df if predict_df is not None else sql_df
282
- if explain_data is not None and not explain_data.empty:
283
- explain_imgs = explain_tool.run(explain_data)
284
- artifacts["explain_charts"] = len(explain_imgs)
285
- logger.info(f"Generated {len(explain_imgs)} explanation charts")
286
- else:
287
- errors.append("Explanation skipped: no data available")
288
- except Exception as e:
289
- error_msg = f"Explanation failed: {str(e)}"
290
- logger.error(error_msg)
291
- errors.append(error_msg)
292
-
293
- # Execute report generation
294
- report_link = None
295
- if "report" in plan["steps"]:
296
- try:
297
- forecast_preview = ts_forecast_df.head(50) if isinstance(ts_forecast_df, pd.DataFrame) else None
298
- report_link = report_tool.render_and_save(
299
- user_query=message,
300
- sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
301
- predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else forecast_preview,
302
- explain_images=explain_imgs,
303
- plan=plan,
304
- )
305
- logger.info(f"Report generated: {report_link}")
306
- except Exception as e:
307
- error_msg = f"Report generation failed: {str(e)}"
308
- logger.error(error_msg)
309
- errors.append(error_msg)
310
-
311
- # Log human-in-the-loop decision
312
- tracer.trace_event("hitl", {
313
- "message": message[:500],
314
- "decision": hitl_decision,
315
- "reviewer_note": reviewer_note[:500] if reviewer_note else "",
316
- "artifacts": artifacts,
317
- "plan": plan,
318
- "errors": errors,
319
- })
320
-
321
- # Compose response
322
- response = f"**Plan:** {', '.join(plan['steps'])}\n\n**Rationale:** {plan['rationale']}\n\n"
323
-
324
- # Add artifacts info
325
- if artifacts:
326
- response += "**Results:**\n"
327
- if "sql_rows" in artifacts:
328
- response += f"- SQL query returned {artifacts['sql_rows']} rows\n"
329
- if "predict_rows" in artifacts:
330
- response += f"- Generated predictions for {artifacts['predict_rows']} rows\n"
331
- if "forecast_horizon" in artifacts:
332
- response += f"- Forecast generated for {artifacts['forecast_horizon']} periods\n"
333
- if "explain_charts" in artifacts:
334
- response += f"- Created {artifacts['explain_charts']} explanation charts\n"
335
- response += "\n"
336
-
337
- # Add report link
338
- if report_link:
339
- response += f"📄 **Report:** {report_link}\n\n"
340
-
341
- # Add trace URL
342
- if tracer.trace_url:
343
- response += f"🔍 **Trace:** {tracer.trace_url}\n\n"
344
-
345
- # Add errors if any
346
- if errors:
347
- response += "**⚠️ Warnings/Errors:**\n"
348
- for err in errors:
349
- response += f"- {err}\n"
350
-
351
- # Determine preview dataframe
352
- if isinstance(ts_forecast_df, pd.DataFrame) and not ts_forecast_df.empty:
353
- preview_df = ts_forecast_df.head(100)
354
- elif isinstance(predict_df, pd.DataFrame) and not predict_df.empty:
355
- preview_df = predict_df.head(100)
356
- elif isinstance(sql_df, pd.DataFrame) and not sql_df.empty:
357
- preview_df = sql_df.head(100)
358
- else:
359
- preview_df = pd.DataFrame({"message": ["No data to display"]})
360
-
361
- return response, preview_df
362
-
363
  except Exception as e:
364
- error_msg = f"Unexpected error in agent execution: {str(e)}"
365
- logger.exception(error_msg)
366
- tracer.trace_event("error", {"message": error_msg})
367
- return f"❌ **Critical Error:** {error_msg}", pd.DataFrame()
368
-
369
-
370
- # Gradio Interface
371
- with gr.Blocks(title="Tabular Agentic XAI") as demo:
372
- gr.Markdown("""
373
- # 🤖 Tabular Agentic XAI (Enterprise Edition)
374
-
375
- An intelligent assistant for analyzing tabular data with ML predictions, explanations, and time-series forecasting.
376
-
377
- **Capabilities:**
378
- - 📊 SQL queries and data retrieval
379
- - 🎯 ML predictions with confidence scores
380
- - 🔍 SHAP-based model explanations
381
- - 📈 Time-series forecasting with Granite TTM
382
- - 📄 Automated report generation
383
- """)
384
-
385
- with gr.Row():
386
- msg = gr.Textbox(
387
- label="Ask your question",
388
- placeholder="e.g., Show me the top 10 customers by revenue, predict churn risk, forecast next quarter...",
389
- lines=3
390
- )
391
-
392
- with gr.Row():
393
- hitl = gr.Radio(
394
- ["Approve", "Needs Changes"],
395
- value="Approve",
396
- label="Human Review",
397
- info="Review the planned actions before execution"
398
- )
399
- note = gr.Textbox(
400
- label="Reviewer note (optional)",
401
- placeholder="Add any review comments...",
402
- lines=2
403
- )
404
-
405
- out_md = gr.Markdown(label="Response")
406
- out_df = gr.Dataframe(
407
- interactive=False,
408
- label="Data Preview (max 100 rows)",
409
- wrap=True
410
- )
411
-
412
- with gr.Row():
413
- ask = gr.Button("🚀 Run Analysis", variant="primary")
414
- clear = gr.Button("🔄 Clear")
415
-
416
- ask.click(
417
- run_agent,
418
- inputs=[msg, hitl, note],
419
- outputs=[out_md, out_df]
420
- )
421
-
422
- clear.click(
423
- lambda: ("", "Approve", "", "", pd.DataFrame()),
424
- outputs=[msg, hitl, note, out_md, out_df]
425
- )
426
-
427
- gr.Markdown("""
428
- ---
429
- **Tips:**
430
- - Be specific in your queries for better results
431
- - Use natural language - the system will interpret your intent
432
- - Review the execution plan before approving
433
- - Check the trace link for detailed execution logs
434
- """)
435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
 
437
  if __name__ == "__main__":
438
- logger.info("Starting Gradio application...")
439
- demo.launch(
440
- server_name="0.0.0.0",
441
- server_port=7860,
442
- show_error=True
443
- )
 
# app.py
import os

import pandas as pd
import gradio as gr

from tools.sql_tool import SQLTool
from tools.ts_preprocess import build_timeseries

# Ensure DB path & defaults (you can set these in Space Settings → Variables)
DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")  # path to the DuckDB database file
DEFAULT_SCHEMA = os.getenv("SQL_DEFAULT_SCHEMA", "main")  # schema shown in the example SQL
DEFAULT_TABLE = os.getenv("SQL_DEFAULT_TABLE", "masterdataset_v")  # table shown in the example SQL

# Single shared SQL tool instance bound to the DuckDB file above.
sql_tool = SQLTool(DUCKDB_PATH)

# Markdown banner shown at the top of the UI; interpolates connection details.
INTRO = f"""
### ALM LLM — Demo
Connected to **DuckDB** at `{DUCKDB_PATH}` using table **{DEFAULT_SCHEMA}.{DEFAULT_TABLE}**.

**Try:**
- *"show me the top 10 fds by portfolio value"*
- *"top 10 assets by portfolio value"*
- *"sum portfolio value by currency"*
"""
 
 
 
 
 
 
 
 
 
25
 
26
def run_nl(nl_query: str):
    """Answer a natural-language question against the DuckDB table.

    Args:
        nl_query: Free-form user question (e.g. "top 10 fds by portfolio value").

    Returns:
        Tuple of (result_df, generated_sql, reasoning_or_status, cashflows_df,
        gap_df) matching the five outputs of the "Ask in Natural Language" tab.
        On validation or query failure the data frames are empty and the third
        element carries the message.
    """
    # Guard: empty/whitespace-only input never reaches the SQL tool.
    if not nl_query or not nl_query.strip():
        return pd.DataFrame(), "", "Please enter a query.", pd.DataFrame(), pd.DataFrame()

    try:
        df, sql, why = sql_tool.query_from_nl(nl_query)
    except Exception as e:
        return pd.DataFrame(), "", f"Error: {e}", pd.DataFrame(), pd.DataFrame()

    # Try to build timeseries cashflows + gap if columns match masterdataset_v.
    # Best-effort: unsuitable result schemas simply yield empty frames.
    try:
        cf, gap = build_timeseries(df)
    except Exception:  # fix: was `as e` with `e` never used (F841); matches run_sql
        cf, gap = pd.DataFrame(), pd.DataFrame()

    return df, sql.strip(), why, cf, gap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
def run_sql(sql_text: str):
    """Execute a raw SQL statement and derive optional time-series views.

    Returns a 4-tuple of (result_df, status_text, cashflows_df, gap_df)
    matching the outputs of the "Run Raw SQL" tab. Failures are reported in
    the status text with empty data frames.
    """
    # Reject empty or whitespace-only statements up front.
    if not (sql_text and sql_text.strip()):
        return pd.DataFrame(), "Please paste a SQL statement.", pd.DataFrame(), pd.DataFrame()

    try:
        result = sql_tool.run_sql(sql_text)
    except Exception as e:
        return pd.DataFrame(), f"Error: {e}", pd.DataFrame(), pd.DataFrame()

    # Time-series projections are best-effort; a result schema that does not
    # match the expected columns just produces empty frames.
    try:
        cashflows, gap = build_timeseries(result)
    except Exception:
        cashflows, gap = pd.DataFrame(), pd.DataFrame()

    return result, "OK", cashflows, gap
57
+
58
with gr.Blocks(title="ALM LLM") as demo:
    # Connection banner with example prompts.
    gr.Markdown(INTRO)

    # --- Tab 1: natural-language querying --------------------------------
    with gr.Tab("Ask in Natural Language"):
        nl_input = gr.Textbox(label="Ask a question", placeholder="e.g., show me the top 10 fds by portfolio value")
        run_button = gr.Button("Run")
        generated_sql = gr.Textbox(label="Generated SQL", interactive=False)
        reasoning = gr.Textbox(label="Reasoning", interactive=False)
        result_table = gr.Dataframe(label="Query Result", wrap=True)
        cashflow_table = gr.Dataframe(label="Projected Cash-Flows (if applicable)", wrap=True, height=250)
        gap_table = gr.Dataframe(label="Liquidity Gap (monthly)", wrap=True, height=200)

        run_button.click(
            fn=run_nl,
            inputs=[nl_input],
            outputs=[result_table, generated_sql, reasoning, cashflow_table, gap_table],
        )

    # --- Tab 2: raw SQL execution ----------------------------------------
    with gr.Tab("Run Raw SQL"):
        sql_editor = gr.Code(label="SQL", language="sql", value=f"SELECT * FROM {DEFAULT_SCHEMA}.{DEFAULT_TABLE} LIMIT 20;")
        exec_button = gr.Button("Execute")
        raw_result = gr.Dataframe(label="Result", wrap=True)
        status_box = gr.Textbox(label="Status", interactive=False)
        raw_cashflows = gr.Dataframe(label="Projected Cash-Flows (if applicable)", wrap=True, height=250)
        raw_gap = gr.Dataframe(label="Liquidity Gap (monthly)", wrap=True, height=200)

        exec_button.click(
            fn=run_sql,
            inputs=[sql_editor],
            outputs=[raw_result, status_box, raw_cashflows, raw_gap],
        )
81
 
82
  if __name__ == "__main__":
83
+ # Spaces set PORT automatically; otherwise, Gradio defaults are fine.
84
+ demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))