|
|
|
|
|
import os |
|
|
import re |
|
|
import json |
|
|
import pandas as pd |
|
|
from utils.config import AppConfig |
|
|
from utils.tracing import Tracer |
|
|
|
|
|
|
|
|
class SQLTool:
    """Run SQL against the configured warehouse and return a DataFrame.

    The backend is chosen by ``cfg.sql_backend``; ``"bigquery"`` and
    ``"motherduck"`` are supported. The matching client is created eagerly in
    ``__init__`` and stored on ``self.client``.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Create the client for the backend named by ``cfg.sql_backend``.

        Args:
            cfg: Application config. Reads ``sql_backend`` and, depending on
                the backend, ``gcp_project`` or ``motherduck_token`` /
                ``motherduck_db``.
            tracer: Object whose ``trace_event(name, payload)`` is called
                (best-effort) for every query executed by :meth:`run`.

        Raises:
            RuntimeError: Unknown backend, missing credentials, or an
                incompatible DuckDB version.
            ValueError: ``GCP_SERVICE_ACCOUNT_JSON`` is set but does not hold
                a JSON object.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend

        if self.backend == "bigquery":
            self.client = self._connect_bigquery()
        elif self.backend == "motherduck":
            self.client = self._connect_motherduck()
        else:
            raise RuntimeError("Unknown SQL backend")

    def _connect_bigquery(self):
        """Build a BigQuery client from the GCP_SERVICE_ACCOUNT_JSON env var."""
        # Imported lazily so the dependency is only needed for this backend.
        from google.cloud import bigquery
        from google.oauth2 import service_account

        key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
        if not key_json:
            raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")

        # Fail with a clear message instead of handing an empty dict to
        # google-auth, which would only report "missing fields".
        if not key_json.strip().startswith("{"):
            raise ValueError(
                "GCP_SERVICE_ACCOUNT_JSON must contain the service-account "
                "key as a JSON object"
            )

        info = json.loads(key_json)
        creds = service_account.Credentials.from_service_account_info(info)
        return bigquery.Client(credentials=creds, project=self.cfg.gcp_project)

    def _connect_motherduck(self):
        """Open a DuckDB connection to MotherDuck after checking the version pin."""
        # Imported lazily so the dependency is only needed for this backend.
        import duckdb

        # The negative lookahead keeps "1.3.2" and "1.3.2.devN" acceptable
        # while rejecting look-alikes such as "1.3.20".
        if not re.match(r"1\.3\.2(?!\d)", duckdb.__version__):
            raise RuntimeError(
                f"Incompatible DuckDB version {duckdb.__version__}. "
                "Pin duckdb==1.3.2 in requirements.txt and redeploy."
            )

        # An explicit config value wins over the environment variable.
        token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
        db_name = (self.cfg.motherduck_db or "workspace").strip()
        if not token:
            raise RuntimeError("Missing MOTHERDUCK_TOKEN")

        uri = f"md:{db_name}?motherduck_token={token}"
        return duckdb.connect(uri)

    def _nl_to_sql(self, message: str) -> str:
        """
        Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
        Edit table/column names to your schema.

        Resolution order:
          1. A message that already starts with ``SELECT`` followed by any
             whitespace is returned unchanged.
          2. A message containing "avg" and " by " yields the monthly-average
             template.
          3. Anything else yields a 100-row preview query.
        """
        m = message.lower()

        # Raw SQL must be recognised first; otherwise a query such as
        # "SELECT avg(x) ... GROUP BY y" would be hijacked by the "avg ... by"
        # heuristic below. ``\s`` also accepts a newline or tab after SELECT.
        # NOTE(review): the statement is passed through without validation.
        if re.match(r"\s*select\s", m):
            return message

        if "avg" in m and " by " in m:
            return (
                "-- Example template; edit me\n"
                "SELECT DATE_TRUNC('month', date_col) AS month, "
                "AVG(metric) AS avg_metric "
                "FROM analytics.table "
                "GROUP BY 1 "
                "ORDER BY 1;"
            )

        return "SELECT * FROM analytics.table LIMIT 100;"

    def run(self, message: str) -> pd.DataFrame:
        """Translate ``message`` to SQL, execute it, and return the rows.

        Args:
            message: Natural-language request or a raw ``SELECT`` statement.

        Returns:
            The query result as produced by the backend client
            (``to_dataframe()`` for BigQuery, ``fetch_df()`` otherwise).
        """
        sql = self._nl_to_sql(message)

        # Tracing is best-effort: a broken tracer must never block the query,
        # but the failure is logged rather than silently discarded.
        try:
            self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
        except Exception:
            import logging

            logging.getLogger(__name__).debug(
                "trace_event failed for sql_query", exc_info=True
            )

        if self.backend == "bigquery":
            return self.client.query(sql).to_dataframe()
        return self.client.execute(sql).fetch_df()
|
|
|