import ast
import json
import os
import re
from typing import Optional

import pandas as pd

from utils.config import AppConfig
from utils.tracing import Tracer


class SQLTool:
    """Run SQL against the configured backend and return a DataFrame.

    The backend is chosen by ``cfg.sql_backend`` and must be either
    ``"bigquery"`` or ``"motherduck"``. The matching client library is
    imported lazily in ``__init__`` so only the selected backend's
    dependency needs to be installed.
    """

    # Matches a message that already is a SELECT statement. Any whitespace
    # (space, tab, newline) may follow the keyword.
    _EXPLICIT_SELECT = re.compile(r"^\s*select\s", re.IGNORECASE)

    def __init__(self, cfg: AppConfig, tracer: Tracer):
        """Create the backend client.

        Args:
            cfg: Application config. Reads ``sql_backend`` and, depending on
                the backend, ``gcp_project`` or ``motherduck_token`` /
                ``motherduck_db``.
            tracer: Tracer used by ``run`` to record each executed query.

        Raises:
            RuntimeError: If the backend is unknown, or a required secret is
                missing or (for BigQuery) cannot be parsed into a mapping.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend  # "bigquery" or "motherduck"

        if self.backend == "bigquery":
            from google.cloud import bigquery
            from google.oauth2 import service_account

            key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
            if not key_json:
                raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
            creds = service_account.Credentials.from_service_account_info(
                self._load_service_account_info(key_json)
            )
            self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
        elif self.backend == "motherduck":
            import duckdb

            token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
            if not token:
                # Without this guard the literal text "None" would be sent
                # as the token in the connection string below.
                raise RuntimeError("Missing MOTHERDUCK_TOKEN secret")
            db_name = self.cfg.motherduck_db or "default"
            self.client = duckdb.connect(f"md:/{db_name}?motherduck_token={token}")
        else:
            raise RuntimeError("Unknown SQL backend")

    @staticmethod
    def _load_service_account_info(raw: str) -> dict:
        """Parse the service-account secret into a dict without executing it.

        NOTE: this replaces an earlier ``eval(raw)``, which would run
        arbitrary code found in the environment variable. JSON is tried
        first; ``ast.literal_eval`` is kept as a fallback so a value stored
        as a Python dict literal (e.g. single-quoted keys) still loads.

        Raises:
            RuntimeError: If ``raw`` does not parse to a mapping.
        """
        try:
            info = json.loads(raw)
        except json.JSONDecodeError:
            try:
                info = ast.literal_eval(raw.strip())
            except (ValueError, SyntaxError) as err:
                raise RuntimeError(
                    "GCP_SERVICE_ACCOUNT_JSON is not valid JSON"
                ) from err
        if not isinstance(info, dict):
            raise RuntimeError("GCP_SERVICE_ACCOUNT_JSON must contain a JSON object")
        return info

    def _nl_to_sql(self, message: str) -> str:
        """Turn a user message into a SQL string.

        Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
        Expect users to include table names.
        Example: "avg revenue by month from dataset.sales"

        Resolution order:
          1. A message that already starts with ``SELECT`` is returned
             unchanged.
          2. A message containing "avg" and " by " yields a placeholder
             aggregation template.
          3. Anything else yields a default ``SELECT * ... LIMIT 100``.
        """
        # Explicit SQL is checked first: a hand-written query such as
        # "SELECT AVG(x) ... GROUP BY y" also contains "avg" and " by " and
        # must not be swapped for the template.
        # NOTE(review): the message is forwarded verbatim and executed by
        # run(); confirm callers only pass trusted input.
        if self._EXPLICIT_SELECT.match(message):
            return message

        lowered = message.lower()
        if "avg" in lowered and " by " in lowered:
            return (
                "-- Example template; edit me\n"
                "SELECT DATE_TRUNC(month, date_col) AS month, "
                "AVG(metric) AS avg_metric "
                "FROM dataset.table GROUP BY 1 ORDER BY 1;"
            )

        return "SELECT * FROM dataset.table LIMIT 100;"

    def run(self, message: str) -> pd.DataFrame:
        """Translate ``message`` to SQL, trace it, execute it, return rows.

        Args:
            message: Natural-language request or an explicit SELECT statement.

        Returns:
            The query result as produced by the backend client
            (``to_dataframe()`` for BigQuery, ``fetch_df()`` otherwise).
        """
        sql = self._nl_to_sql(message)
        self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})

        if self.backend == "bigquery":
            df = self.client.query(sql).to_dataframe()
        else:
            df = self.client.execute(sql).fetch_df()
        return df