|
|
|
|
|
import os |
|
|
import re |
|
|
import json |
|
|
import pandas as pd |
|
|
from utils.config import AppConfig |
|
|
from utils.tracing import Tracer |
|
|
|
|
|
|
|
|
# Database names for which SQLTool connects to MotherDuck without selecting a
# specific database and skips the USE / CREATE DATABASE step.
RESERVED_MD_WORKSPACE_NAMES = {"", "workspace", "default"}


class SQLTool:
    """Run SQL derived from a chat message against BigQuery or MotherDuck.

    The backend is chosen by ``cfg.sql_backend`` (``"bigquery"`` or
    ``"motherduck"``). Backend drivers are imported lazily inside
    ``__init__`` so only the selected backend's package has to be installed.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Create the backend client selected by ``cfg.sql_backend``.

        Args:
            cfg: Application config. Reads ``sql_backend`` and, depending on
                the backend, ``gcp_project`` or ``motherduck_token`` /
                ``motherduck_db``.
            tracer: Tracer used by :meth:`run` to record executed queries.

        Raises:
            RuntimeError: If a required secret is missing, the installed
                DuckDB version is not 1.3.2, the target database cannot be
                selected/created, or the backend name is unknown.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend

        if self.backend == "bigquery":
            from google.cloud import bigquery
            from google.oauth2 import service_account

            key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
            if not key_json:
                raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")

            # Only parse the secret when it looks like a JSON object; anything
            # else yields an empty dict and is rejected by the credentials
            # loader below.
            info = json.loads(key_json) if key_json.strip().startswith("{") else {}
            creds = service_account.Credentials.from_service_account_info(info)
            self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)

        elif self.backend == "motherduck":
            import duckdb

            if not duckdb.__version__.startswith("1.3.2"):
                raise RuntimeError(
                    f"Incompatible DuckDB version {duckdb.__version__}. "
                    "Pin duckdb==1.3.2 in requirements.txt and redeploy."
                )

            token = (self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN") or "").strip()
            db_name = (self.cfg.motherduck_db or "workspace").strip()
            allow_create = (os.getenv("ALLOW_CREATE_DB", "true").lower() == "true")
            if not token:
                raise RuntimeError("Missing MOTHERDUCK_TOKEN")

            if db_name in RESERVED_MD_WORKSPACE_NAMES:
                # Reserved name: connect without naming a database.
                self.client = duckdb.connect(f"md:?motherduck_token={token}")
            else:
                try:
                    # First try to connect straight to the requested database.
                    self.client = duckdb.connect(f"md:{db_name}?motherduck_token={token}")
                except Exception:
                    # Fallback: connect without a database, then USE it
                    # (creating it first when allowed).
                    self.client = duckdb.connect(f"md:?motherduck_token={token}")
                    self._ensure_db_context(db_name, allow_create)

        else:
            raise RuntimeError(f"Unknown SQL backend: {self.backend}")

    def _ensure_db_context(self, db_name: str, allow_create: bool) -> None:
        """Make ``db_name`` the current database on ``self.client``.

        Tries ``USE <db>``; if that fails and ``allow_create`` is true, runs
        ``CREATE DATABASE <db>`` followed by ``USE <db>``. Reserved workspace
        names are skipped. The name is passed through :meth:`_quote_ident`
        before being placed in the statement.

        Raises:
            RuntimeError: If the database cannot be used and creation is
                disabled, or if creating/using it fails. The underlying
                exception is attached as ``__cause__``.
        """
        if db_name in RESERVED_MD_WORKSPACE_NAMES:
            return

        ident = self._quote_ident(db_name)

        try:
            self.client.execute(f"USE {ident};")
            return
        except Exception as err:
            # Keep a reference: the ``as`` name is unbound when the except
            # block ends, but the error is still needed for reporting below.
            use_err = err

        if not allow_create:
            raise RuntimeError(
                f"Database '{db_name}' not found and ALLOW_CREATE_DB is false. "
                f"Original error: {use_err}"
            ) from use_err

        try:
            self.client.execute(f"CREATE DATABASE {ident};")
            self.client.execute(f"USE {ident};")
        except Exception as create_err:
            # Report both failures so the root cause is not hidden.
            raise RuntimeError(
                f"Could not create or use database '{db_name}'. "
                f"Original errors: USE: {use_err}; CREATE: {create_err}"
            ) from create_err

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Return ``name`` with every char outside ``[a-zA-Z0-9_]`` replaced by ``_``.

        Despite the method name, no quote characters are added: this
        sanitises the identifier so it can be interpolated into a statement.
        ``None`` or an empty string yields ``""``.
        """
        return re.sub(r"[^a-zA-Z0-9_]", "_", (name or ""))

    def _nl_to_sql(self, message: str) -> str:
        """Translate a chat message into a SQL string.

        Rules, in order:
          1. A message that already starts with ``select`` (any case, leading
             whitespace allowed) is returned unchanged.
          2. A message containing ``avg`` and `` by `` yields an example
             monthly-average template.
          3. Anything else yields a default ``SELECT * ... LIMIT 100`` query.
        """
        m = message.lower()

        # Raw SQL passthrough is checked first so that a hand-written query
        # such as "SELECT AVG(x) ... GROUP BY y" is not replaced by the
        # template below.
        if re.match(r"\s*select\s", m):
            return message

        if "avg" in m and " by " in m:
            return (
                "-- Example template; edit to your schema/columns\n"
                "SELECT DATE_TRUNC('month', date_col) AS month,\n"
                "       AVG(metric) AS avg_metric\n"
                "FROM analytics.table\n"
                "GROUP BY 1\n"
                "ORDER BY 1;"
            )

        return "SELECT * FROM analytics.table LIMIT 100;"

    def run(self, message: str) -> pd.DataFrame:
        """Convert ``message`` to SQL, trace it, execute it, and return a DataFrame.

        NOTE(review): messages starting with ``select`` are executed verbatim;
        confirm that callers only pass trusted input.
        """
        sql = self._nl_to_sql(message)
        try:
            self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
        except Exception:
            # Tracing is best-effort: a tracer failure must not block the query.
            pass

        if self.backend == "bigquery":
            return self.client.query(sql).to_dataframe()
        else:
            return self.client.execute(sql).fetch_df()
|
|
|