# tools/sql_tool.py
import os
import re
import duckdb
from typing import Iterable, Optional, Tuple

DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")

# Defaults point to your real table; can be overridden via Space secrets
DEFAULT_SCHEMA = os.getenv("SQL_DEFAULT_SCHEMA", "main")
DEFAULT_TABLE = os.getenv("SQL_DEFAULT_TABLE", "masterdataset_v")

# Patterns are compiled once and applied to the lower-cased user message.
_TOP_N_RE = re.compile(r"\btop\s+(\d+)")
# "currency usd" / "in lkr": the captured word must be exactly three letters.
_CURRENCY_RE = re.compile(r"\b(?:currency|in)\s+([a-z]{3})\b")
# "segment retail" / "segments retail" / "for corporate" / "for the retail ...".
# Only a single word is captured, so trailing clauses ("in usd", "top 5") are not
# swallowed into the segment value. The capture lives in a lookahead so matches may
# overlap: "for segment retail" yields both "segment" and "retail".
_SEGMENT_RE = re.compile(r"\b(?:segments?|for)\s+(?:the\s+)?(?=([a-z0-9_\-]+))")
# Words the two patterns above can pick up that are never real filter values.
_CURRENCY_STOPWORDS = frozenset({"the", "all", "any", "and"})
_SEGMENT_STOPWORDS = frozenset({"the", "segment", "segments"})


def _full_table(schema: Optional[str] = None, table: Optional[str] = None) -> str:
    """Return ``schema.table``, substituting the configured defaults for falsy parts."""
    schema = schema or DEFAULT_SCHEMA
    table = table or DEFAULT_TABLE
    return f"{schema}.{table}"


def _has_any(text: str, words: Iterable[str]) -> bool:
    """Return True if any of *words* occurs as a substring of *text*."""
    return any(w in text for w in words)


def _first_match(pattern: "re.Pattern", text: str, skip: frozenset) -> Optional[str]:
    """Return the first ``group(1)`` of *pattern* in *text* that is not in *skip*."""
    for match in pattern.finditer(text):
        word = match.group(1)
        if word not in skip:
            return word
    return None


class SQLTool:
    """
    Minimal NL→SQL helper with a DuckDB runner.

    Queries target ``DEFAULT_SCHEMA.DEFAULT_TABLE`` (``main.masterdataset_v`` unless
    overridden via the SQL_DEFAULT_SCHEMA / SQL_DEFAULT_TABLE environment variables).
    The connection is opened in ``__init__``; call ``close()`` or use the instance as
    a context manager to release it.
    """

    def __init__(self, db_path: Optional[str] = None):
        # Falls back to the DUCKDB_PATH environment variable (default "alm.duckdb").
        self.db_path = db_path or DUCKDB_PATH
        self.con = duckdb.connect(self.db_path)

    def close(self) -> None:
        """Close the underlying DuckDB connection."""
        self.con.close()

    def __enter__(self) -> "SQLTool":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def run_sql(self, sql: str):
        """Execute *sql* on the connection and return the result via DuckDB's ``.df()``."""
        return self.con.execute(sql).df()

    # -------------------------
    # NL → SQL
    # -------------------------
    @staticmethod
    def _top_by_portfolio_sql(full_table: str, product: str, n: int) -> str:
        """Build the "top *n* rows of *product* ordered by Portfolio_value" query."""
        return f"""
            SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
            FROM {full_table}
            WHERE lower(product) = '{product}'
            ORDER BY Portfolio_value DESC
            LIMIT {n};
        """

    def _nl_to_sql(self, message: str, schema: Optional[str] = None, table: Optional[str] = None) -> Tuple[str, str]:
        """
        Translate *message* into ``(sql, rationale)`` using a small template library.

        Templates are tried in order:
          1. top-N fixed deposits by Portfolio_value (N from "top N", default 10)
          2. top-N assets by Portfolio_value (same N rule)
          3. SUM/AVG of Portfolio_value grouped by segments or currency
          4. row filter on product / currency / segment, plus LIMIT when "top N" is given
        When step 4 recognises neither a filter nor a limit, a 20-row sample of the
        table is returned instead of an unbounded ``SELECT *``.
        """
        full_table = _full_table(schema, table)
        m = message.strip().lower()

        # Extract a "top N"; a literal 0 is treated like "not given" below.
        top_match = _TOP_N_RE.search(m)
        limit = int(top_match.group(1)) if top_match else None

        asks_top = _has_any(m, ["top", "largest", "biggest"])
        mentions_pv = _has_any(m, ["portfolio value", "portfolio_value"])

        # 1) Top N FDs by Portfolio_value
        if _has_any(m, ["fd", "fixed deposit", "deposits"]) and asks_top and mentions_pv:
            n = limit or 10
            sql = self._top_by_portfolio_sql(full_table, "fd", n)
            return sql, f"Top {n} fixed deposits by Portfolio_value from {full_table}"

        # 2) Top N Assets by Portfolio_value
        if _has_any(m, ["asset", "loan", "advances"]) and asks_top and mentions_pv:
            n = limit or 10
            sql = self._top_by_portfolio_sql(full_table, "assets", n)
            return sql, f"Top {n} assets by Portfolio_value from {full_table}"

        # 3) Aggregate (SUM/AVG) by segment or currency
        if _has_any(m, ["sum", "total", "avg", "average"]) and _has_any(m, ["segment", "currency"]):
            agg = "SUM" if _has_any(m, ["sum", "total"]) else "AVG"
            dim = "segments" if "segment" in m else "currency"
            sql = f"""
                SELECT {dim}, {agg}(Portfolio_value) AS {agg.lower()}_Portfolio_value
                FROM {full_table}
                GROUP BY 1
                ORDER BY 2 DESC;
            """
            return sql, f"{agg} Portfolio_value grouped by {dim} from {full_table}"

        # 4) Filter by product, currency, or segment
        if "fd" in m or "deposit" in m:
            product = "fd"
        elif "asset" in m or "loan" in m or "advance" in m:
            product = "assets"
        else:
            product = None

        filters = []
        why_parts = [f"Filtered rows from {full_table}"]
        if product:
            filters.append(f"AND lower(product) = '{product}'")
            why_parts.append(f"product = {product}")

        # The regexes only admit [a-z0-9_-], so these values cannot contain quotes and
        # are safe to inline as SQL literals. Note '_' still acts as a LIKE wildcard.
        currency = _first_match(_CURRENCY_RE, m, _CURRENCY_STOPWORDS)
        if currency:
            currency = currency.upper()
            filters.append(f"AND upper(currency) = '{currency}'")
            why_parts.append(f"currency = {currency}")

        segment = _first_match(_SEGMENT_RE, m, _SEGMENT_STOPWORDS)
        if segment:
            filters.append(f"AND lower(segments) LIKE '%{segment}%'")
            why_parts.append(f"segments like '{segment}'")

        # 5) Nothing recognised: return a bounded sample rather than the whole table.
        if not filters and not limit:
            return f"SELECT * FROM {full_table} LIMIT 20;", f"Default sample from {full_table}"

        parts = [f"SELECT * FROM {full_table} WHERE 1=1", *filters]
        if limit:
            parts.append(f"LIMIT {limit}")
        return " ".join(parts) + ";", "; ".join(why_parts)

    # Public helpers
    def query_from_nl(self, message: str, schema: Optional[str] = None, table: Optional[str] = None):
        """
        Translate *message* to SQL, run it, and return ``(df, sql, rationale)``.

        *schema* / *table* override the configured defaults for this query only.
        """
        sql, why = self._nl_to_sql(message, schema, table)
        df = self.run_sql(sql)
        return df, sql, why

    def table_exists(self, schema: Optional[str] = None, table: Optional[str] = None) -> bool:
        """Return True if ``information_schema.tables`` lists *schema*.*table*."""
        schema = schema or DEFAULT_SCHEMA
        table = table or DEFAULT_TABLE
        # Bound parameters instead of string interpolation: callers may pass
        # arbitrary names, and quoting them by hand is error-prone.
        q = (
            "SELECT COUNT(*) AS n FROM information_schema.tables "
            "WHERE table_schema = ? AND table_name = ?;"
        )
        n = self.con.execute(q, [schema, table]).fetchone()[0]
        return n > 0