ALM_LLM / tools /sql_tool.py
AshenH's picture
Update tools/sql_tool.py
85b8a4e verified
raw
history blame
5.47 kB
# space/tools/sql_tool.py
import os
import re
import json
import pandas as pd
from utils.config import AppConfig
from utils.tracing import Tracer
# Database names that are treated as "workspace / no specific database"
# context: for these the tool connects without selecting a database and
# never issues USE or CREATE DATABASE.
RESERVED_MD_WORKSPACE_NAMES = {
    "",
    "workspace",
    "default",
}
class SQLTool:
    """Translate a chat message into SQL and run it on the configured backend.

    The backend is selected by ``cfg.sql_backend``:

    * ``"bigquery"`` -- authenticates with the service-account JSON stored in
      the ``GCP_SERVICE_ACCOUNT_JSON`` environment variable.
    * ``"motherduck"`` -- connects through DuckDB using
      ``cfg.motherduck_token`` or the ``MOTHERDUCK_TOKEN`` environment
      variable.

    After construction ``self.client`` holds the live backend client or
    connection.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Create the backend client for ``cfg.sql_backend``.

        Args:
            cfg: Application configuration. Reads ``sql_backend`` and,
                depending on the backend, ``gcp_project``,
                ``motherduck_token`` and ``motherduck_db``.
            tracer: Tracer used by :meth:`run` to record executed queries.

        Raises:
            RuntimeError: If the backend is unknown, a required secret is
                missing or malformed, the installed DuckDB version is not
                1.3.2, or the target MotherDuck database can neither be
                used nor created.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend  # "bigquery" or "motherduck"

        if self.backend == "bigquery":
            self._init_bigquery()
        elif self.backend == "motherduck":
            self._init_motherduck()
        else:
            raise RuntimeError(f"Unknown SQL backend: {self.backend}")

    # ----- Backend initialisation -----
    def _init_bigquery(self) -> None:
        """Build a BigQuery client from ``GCP_SERVICE_ACCOUNT_JSON``."""
        # Imported lazily so the dependency is only needed for this backend.
        from google.cloud import bigquery
        from google.oauth2 import service_account

        key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
        if not key_json:
            raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")

        # Fail early with a clear message instead of handing an empty dict to
        # google-auth. The secret itself is never echoed in the error.
        try:
            info = json.loads(key_json)
        except json.JSONDecodeError as err:
            raise RuntimeError(
                "GCP_SERVICE_ACCOUNT_JSON is not valid JSON"
            ) from err
        if not isinstance(info, dict):
            raise RuntimeError(
                "GCP_SERVICE_ACCOUNT_JSON must contain a JSON object"
            )

        creds = service_account.Credentials.from_service_account_info(info)
        self.client = bigquery.Client(
            credentials=creds, project=self.cfg.gcp_project
        )

    def _init_motherduck(self) -> None:
        """Open a MotherDuck connection and select the configured database."""
        # Imported lazily so the dependency is only needed for this backend.
        import duckdb

        # MotherDuck extension compatibility: widely supported ABI is DuckDB 1.3.2
        if not duckdb.__version__.startswith("1.3.2"):
            raise RuntimeError(
                f"Incompatible DuckDB version {duckdb.__version__}. "
                "Pin duckdb==1.3.2 in requirements.txt and redeploy."
            )

        token = (
            self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN") or ""
        ).strip()
        db_name = (self.cfg.motherduck_db or "workspace").strip()
        allow_create = (
            os.getenv("ALLOW_CREATE_DB", "true").strip().lower() == "true"
        )

        if not token:
            raise RuntimeError("Missing MOTHERDUCK_TOKEN")

        workspace_dsn = f"md:?motherduck_token={token}"

        if db_name in RESERVED_MD_WORKSPACE_NAMES:
            # Workspace mode: no USE/CREATE; callers should fully-qualify
            # table names if needed.
            self.client = duckdb.connect(workspace_dsn)
            return

        try:
            # Preferred: connect straight to the target database.
            self.client = duckdb.connect(
                f"md:{db_name}?motherduck_token={token}"
            )
        except Exception:
            # Deliberate best-effort fallback: connect to the workspace, then
            # USE (or CREATE, if permitted) the database.
            self.client = duckdb.connect(workspace_dsn)
            self._ensure_db_context(db_name, allow_create)

    # ----- MotherDuck helpers -----
    def _ensure_db_context(self, db_name: str, allow_create: bool):
        """Make ``db_name`` the active database on ``self.client``.

        Tries ``USE`` first; if that fails and ``allow_create`` is true, runs
        ``CREATE DATABASE`` followed by ``USE``. Reserved workspace names are
        a no-op.

        Args:
            db_name: Target database name (sanitised via ``_quote_ident``).
            allow_create: Whether the database may be created when ``USE``
                fails.

        Raises:
            RuntimeError: If ``USE`` fails and creation is not allowed, or if
                creating/using the database fails.
        """
        if db_name in RESERVED_MD_WORKSPACE_NAMES:
            return

        ident = self._quote_ident(db_name)

        use_error = None
        try:
            self.client.execute(f"USE {ident};")
        except Exception as err:
            # Kept so it can be reported alongside a later CREATE failure.
            use_error = err
        if use_error is None:
            return

        if not allow_create:
            raise RuntimeError(
                f"Database '{db_name}' not found and ALLOW_CREATE_DB is false. "
                f"Original error: {use_error}"
            ) from use_error

        try:
            # CREATE DATABASE <name>; is supported on MotherDuck for valid names (not 'default')
            self.client.execute(f"CREATE DATABASE {ident};")
            self.client.execute(f"USE {ident};")
        except Exception as create_err:
            raise RuntimeError(
                f"Could not create or use database '{db_name}'. "
                f"Original errors: USE: {use_error}; CREATE: {create_err}"
            ) from create_err

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Sanitise ``name`` for use as a bare SQL identifier.

        Despite the method name, no quotes are added: every character outside
        ``[a-zA-Z0-9_]`` is replaced with an underscore, so ``"my-db"``
        becomes ``"my_db"``. ``None`` and ``""`` both yield ``""``.
        """
        return re.sub(r"[^a-zA-Z0-9_]", "_", name or "")

    # ----- NL → SQL heuristic (toy example; edit to your schema) -----
    def _nl_to_sql(self, message: str) -> str:
        """Map a user message to a SQL string.

        Rules, in order:

        1. A message that already starts with ``SELECT`` (any case, leading
           whitespace allowed) is returned unchanged.
        2. A message containing ``avg`` and `` by `` yields a monthly-average
           template query.
        3. Anything else yields a generic ``LIMIT 100`` preview query.
        """
        # Checked first so a real query containing AVG ... GROUP BY is not
        # replaced by the template below.
        # NOTE(review): passthrough SQL is executed verbatim by run(); confirm
        # that callers are trusted or that the backend role is read-only.
        if re.match(r"\s*select\b", message, re.IGNORECASE):
            return message

        lowered = message.lower()

        # Example DuckDB/MotherDuck flavor of DATE_TRUNC
        if "avg" in lowered and " by " in lowered:
            return (
                "-- Example template; edit to your schema/columns\n"
                "SELECT DATE_TRUNC('month', date_col) AS month,\n"
                "       AVG(metric) AS avg_metric\n"
                "FROM analytics.table\n"
                "GROUP BY 1\n"
                "ORDER BY 1;"
            )

        # Fallback
        return "SELECT * FROM analytics.table LIMIT 100;"

    # ----- Execute -----
    def run(self, message: str) -> "pd.DataFrame":
        """Convert ``message`` to SQL, trace it, and execute it.

        Args:
            message: Natural-language request or a raw ``SELECT`` statement.

        Returns:
            The query result as a DataFrame (``to_dataframe()`` for BigQuery,
            ``fetch_df()`` for MotherDuck/DuckDB).
        """
        sql = self._nl_to_sql(message)

        # Tracing is best-effort: a tracer failure must not block the query.
        try:
            self.tracer.trace_event(
                "sql_query", {"sql": sql, "backend": self.backend}
            )
        except Exception:
            pass

        if self.backend == "bigquery":
            return self.client.query(sql).to_dataframe()
        return self.client.execute(sql).fetch_df()