File size: 5,467 Bytes
b07564d f4dc602 b07564d f4dc602 e002acf 85b8a4e f4dc602 e002acf 85b8a4e f4dc602 2336094 f4dc602 e002acf b07564d e002acf f4dc602 e002acf 85b8a4e e002acf b07564d 2336094 9d6bac9 b07564d 9d6bac9 e002acf 2336094 e002acf 85b8a4e 2336094 85b8a4e 2336094 f4dc602 2336094 f4dc602 2336094 85b8a4e 2336094 85b8a4e 2336094 85b8a4e 2336094 85b8a4e 2336094 e002acf 85b8a4e e002acf 85b8a4e 2336094 f4dc602 e002acf 2336094 f4dc602 e002acf 2336094 b07564d e002acf 2336094 85b8a4e f4dc602 e002acf 2336094 e002acf f4dc602 2336094 f4dc602 b07564d e002acf f4dc602 54614e9 f4dc602 54614e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# space/tools/sql_tool.py
import os
import re
import json
import pandas as pd
from utils.config import AppConfig
from utils.tracing import Tracer
# Database names that are treated as "workspace / no specific database":
# for these the tool connects without selecting or creating a database.
RESERVED_MD_WORKSPACE_NAMES = {
    "",
    "workspace",
    "default",
}
class SQLTool:
    """Turn a chat message into SQL and run it on the configured backend.

    The backend is selected by ``cfg.sql_backend``:

    * ``"bigquery"``   -- a ``google.cloud.bigquery.Client`` built from the
      service-account key stored in ``GCP_SERVICE_ACCOUNT_JSON``.
    * ``"motherduck"`` -- a DuckDB connection opened with an ``md:`` URL.

    The backend libraries are imported inside the matching branch of
    ``__init__`` rather than at module import time.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Create the backend client and store it on ``self.client``.

        Args:
            cfg: Application config. ``sql_backend`` is always read;
                ``gcp_project`` is read for BigQuery, ``motherduck_token``
                and ``motherduck_db`` for MotherDuck.
            tracer: Tracer used by :meth:`run` to record each query.

        Raises:
            RuntimeError: If a credential is missing or malformed, the
                installed DuckDB version is not 1.3.2, the target database
                cannot be used/created, or the backend name is unknown.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend  # "bigquery" or "motherduck"

        # ---------------- BIGQUERY BACKEND ----------------
        if self.backend == "bigquery":
            from google.cloud import bigquery
            from google.oauth2 import service_account

            key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
            if not key_json:
                raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
            if not key_json.strip().startswith("{"):
                # Previously this case silently became an empty dict and
                # failed later inside the credentials loader with an opaque
                # "missing fields" error; fail early with a clear message.
                raise RuntimeError(
                    "GCP_SERVICE_ACCOUNT_JSON must contain the service-account "
                    "key as a JSON object"
                )
            info = json.loads(key_json)
            creds = service_account.Credentials.from_service_account_info(info)
            self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)

        # ---------------- MOTHERDUCK BACKEND ----------------
        elif self.backend == "motherduck":
            import duckdb

            # MotherDuck extension compatibility: widely supported ABI is DuckDB 1.3.2
            if not duckdb.__version__.startswith("1.3.2"):
                raise RuntimeError(
                    f"Incompatible DuckDB version {duckdb.__version__}. "
                    "Pin duckdb==1.3.2 in requirements.txt and redeploy."
                )

            # Config value wins; the environment variable is the fallback.
            token = (self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN") or "").strip()
            db_name = (self.cfg.motherduck_db or "workspace").strip()
            allow_create = (os.getenv("ALLOW_CREATE_DB", "true").lower() == "true")
            if not token:
                raise RuntimeError("Missing MOTHERDUCK_TOKEN")

            # Workspace vs concrete DB handling
            if db_name in RESERVED_MD_WORKSPACE_NAMES:
                # Connect to workspace; caller should fully-qualify tables if needed.
                # No USE/CREATE in workspace mode.
                self.client = duckdb.connect(f"md:?motherduck_token={token}")
            else:
                # Try direct connection to the DB (preferred)
                try:
                    self.client = duckdb.connect(f"md:{db_name}?motherduck_token={token}")
                except Exception:
                    # Fallback: connect to workspace, then USE/CREATE the DB if permitted
                    self.client = duckdb.connect(f"md:?motherduck_token={token}")
                    self._ensure_db_context(db_name, allow_create)

        else:
            raise RuntimeError(f"Unknown SQL backend: {self.backend}")

    # ----- MotherDuck helpers -----
    def _ensure_db_context(self, db_name: str, allow_create: bool):
        """Make ``db_name`` the active database on ``self.client``.

        Tries ``USE`` first; if that fails and ``allow_create`` is true,
        issues ``CREATE DATABASE`` followed by ``USE``. Reserved workspace
        names are skipped entirely.

        Args:
            db_name: Target database name (sanitised via ``_quote_ident``
                before being placed in SQL).
            allow_create: Whether a failed ``USE`` may be followed by a
                ``CREATE DATABASE`` attempt.

        Raises:
            RuntimeError: If ``USE`` fails and creation is not allowed, or if
                the create-then-use attempt fails. The underlying exception
                is chained as ``__cause__``.
        """
        if db_name in RESERVED_MD_WORKSPACE_NAMES:
            # No-op for workspace/default
            return

        ident = self._quote_ident(db_name)

        # Attempt USE first
        try:
            self.client.execute(f"USE {ident};")
            return
        except Exception as use_err:
            if not allow_create:
                raise RuntimeError(
                    f"Database '{db_name}' not found and ALLOW_CREATE_DB is false. "
                    f"Original error: {use_err}"
                ) from use_err
            # The name bound by ``except ... as`` is cleared when the handler
            # exits, so keep a reference for the final error message.
            first_error = use_err

        # Attempt CREATE then USE
        try:
            # CREATE DATABASE <name>; is supported on MotherDuck for valid names (not 'default')
            self.client.execute(f"CREATE DATABASE {ident};")
            self.client.execute(f"USE {ident};")
        except Exception as create_err:
            raise RuntimeError(
                f"Could not create or use database '{db_name}'. "
                f"Original errors: USE: {first_error}; CREATE: {create_err}"
            ) from create_err

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Return ``name`` with every character outside ``[a-zA-Z0-9_]`` replaced by ``_``.

        Despite the method name this does not add quotes: it sanitises the
        identifier so it can be interpolated into a statement. ``None`` or an
        empty string yields ``""``.
        """
        return re.sub(r"[^a-zA-Z0-9_]", "_", (name or ""))

    # ----- NL → SQL heuristic (toy example; edit to your schema) -----
    def _nl_to_sql(self, message: str) -> str:
        """Map a chat message to a SQL string using simple heuristics.

        Rules, in order:

        1. A message that already starts with ``SELECT`` (any case, optional
           leading whitespace) is returned unchanged.
        2. A message containing ``avg`` and `` by `` yields a monthly-average
           template query.
        3. Anything else yields a ``SELECT * ... LIMIT 100`` fallback.

        Args:
            message: Raw user text; an empty or ``None`` message falls through
                to the fallback query.

        Returns:
            The SQL text to execute.
        """
        lowered = (message or "").lower()

        # If user typed SQL already, run it as-is. This must be checked before
        # the keyword heuristics, otherwise a query such as
        # "SELECT avg(x) ... GROUP BY y" would be replaced by the template.
        # NOTE(review): the text is passed through unmodified and executed by
        # run(); confirm callers only forward trusted input.
        if re.match(r"\s*select\b", lowered):
            return message

        # Example DuckDB/MotherDuck flavor of DATE_TRUNC
        if "avg" in lowered and " by " in lowered:
            return (
                "-- Example template; edit to your schema/columns\n"
                "SELECT DATE_TRUNC('month', date_col) AS month,\n"
                "       AVG(metric) AS avg_metric\n"
                "FROM analytics.table\n"
                "GROUP BY 1\n"
                "ORDER BY 1;"
            )

        # Fallback
        return "SELECT * FROM analytics.table LIMIT 100;"

    # ----- Execute -----
    def run(self, message: str) -> pd.DataFrame:
        """Translate ``message`` to SQL, trace it, execute it, and return the rows.

        Args:
            message: User text or raw ``SELECT`` statement.

        Returns:
            The query result as a DataFrame (``to_dataframe()`` for BigQuery,
            ``fetch_df()`` otherwise).
        """
        sql = self._nl_to_sql(message)

        # Tracing is best-effort: a tracer failure must not block the query.
        try:
            self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
        except Exception:
            pass

        if self.backend == "bigquery":
            return self.client.query(sql).to_dataframe()
        return self.client.execute(sql).fetch_df()
|