Update tools/sql_tool.py
Browse files- tools/sql_tool.py +37 -31
tools/sql_tool.py
CHANGED
|
@@ -7,12 +7,16 @@ from utils.config import AppConfig
|
|
| 7 |
from utils.tracing import Tracer
|
| 8 |
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
class SQLTool:
|
| 11 |
def __init__(self, cfg: AppConfig, tracer: Tracer):
|
| 12 |
self.cfg = cfg
|
| 13 |
self.tracer = tracer
|
| 14 |
self.backend = cfg.sql_backend # "bigquery" or "motherduck"
|
| 15 |
|
|
|
|
| 16 |
if self.backend == "bigquery":
|
| 17 |
from google.cloud import bigquery
|
| 18 |
from google.oauth2 import service_account
|
|
@@ -25,6 +29,7 @@ class SQLTool:
|
|
| 25 |
creds = service_account.Credentials.from_service_account_info(info)
|
| 26 |
self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
|
| 27 |
|
|
|
|
| 28 |
elif self.backend == "motherduck":
|
| 29 |
import duckdb
|
| 30 |
|
|
@@ -41,22 +46,18 @@ class SQLTool:
|
|
| 41 |
if not token:
|
| 42 |
raise RuntimeError("Missing MOTHERDUCK_TOKEN")
|
| 43 |
|
| 44 |
-
#
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
uri = f"md:{db_name}?motherduck_token={token}" if db_name and db_name != "workspace" else f"md:?motherduck_token={token}"
|
| 48 |
-
self.client = duckdb.connect(uri)
|
| 49 |
-
# If we connected to workspace explicitly, set DB context if provided
|
| 50 |
-
if db_name and db_name != "workspace":
|
| 51 |
-
# Ensure we are actually in the right DB context
|
| 52 |
-
self._ensure_db_context(db_name, allow_create)
|
| 53 |
-
except Exception as e:
|
| 54 |
-
# Fallback: connect to workspace, then create/use DB if needed
|
| 55 |
self.client = duckdb.connect(f"md:?motherduck_token={token}")
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
self._ensure_db_context(db_name, allow_create)
|
| 61 |
|
| 62 |
else:
|
|
@@ -66,35 +67,40 @@ class SQLTool:
|
|
| 66 |
def _ensure_db_context(self, db_name: str, allow_create: bool):
|
| 67 |
"""
|
| 68 |
Try to USE the target DB; if it doesn't exist and allow_create=True, create it and USE it.
|
|
|
|
| 69 |
"""
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
try:
|
| 72 |
self.client.execute(f"USE {self._quote_ident(db_name)};")
|
|
|
|
| 73 |
except Exception as use_err:
|
| 74 |
if not allow_create:
|
| 75 |
raise RuntimeError(
|
| 76 |
f"Database '{db_name}' not found and ALLOW_CREATE_DB is false. "
|
| 77 |
f"Original error: {use_err}"
|
| 78 |
)
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
| 88 |
|
| 89 |
@staticmethod
|
| 90 |
def _quote_ident(name: str) -> str:
|
| 91 |
"""
|
| 92 |
-
Very light identifier quoting.
|
| 93 |
"""
|
| 94 |
-
|
| 95 |
-
return name
|
| 96 |
-
# basic guard; you can tighten rules for your org naming conventions
|
| 97 |
-
safe = re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
| 98 |
return safe
|
| 99 |
|
| 100 |
# ----- NL → SQL heuristic (toy example; edit to your schema) -----
|
|
@@ -113,7 +119,7 @@ class SQLTool:
|
|
| 113 |
)
|
| 114 |
|
| 115 |
# If user typed SQL already, run it as-is
|
| 116 |
-
if re.match(r"
|
| 117 |
return message
|
| 118 |
|
| 119 |
# Fallback
|
|
|
|
| 7 |
from utils.tracing import Tracer
|
| 8 |
|
| 9 |
|
| 10 |
+
RESERVED_MD_WORKSPACE_NAMES = {"", "workspace", "default"} # treat these as workspace/no-DB context
|
| 11 |
+
|
| 12 |
+
|
| 13 |
class SQLTool:
|
| 14 |
def __init__(self, cfg: AppConfig, tracer: Tracer):
|
| 15 |
self.cfg = cfg
|
| 16 |
self.tracer = tracer
|
| 17 |
self.backend = cfg.sql_backend # "bigquery" or "motherduck"
|
| 18 |
|
| 19 |
+
# ---------------- BIGQUERY BACKEND ----------------
|
| 20 |
if self.backend == "bigquery":
|
| 21 |
from google.cloud import bigquery
|
| 22 |
from google.oauth2 import service_account
|
|
|
|
| 29 |
creds = service_account.Credentials.from_service_account_info(info)
|
| 30 |
self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
|
| 31 |
|
| 32 |
+
# ---------------- MOTHERDUCK BACKEND ----------------
|
| 33 |
elif self.backend == "motherduck":
|
| 34 |
import duckdb
|
| 35 |
|
|
|
|
| 46 |
if not token:
|
| 47 |
raise RuntimeError("Missing MOTHERDUCK_TOKEN")
|
| 48 |
|
| 49 |
+
# Workspace vs concrete DB handling
|
| 50 |
+
if db_name in RESERVED_MD_WORKSPACE_NAMES:
|
| 51 |
+
# Connect to workspace; caller should fully-qualify tables if needed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
self.client = duckdb.connect(f"md:?motherduck_token={token}")
|
| 53 |
+
# No USE/CREATE in workspace mode
|
| 54 |
+
else:
|
| 55 |
+
# Try direct connection to the DB (preferred)
|
| 56 |
+
try:
|
| 57 |
+
self.client = duckdb.connect(f"md:{db_name}?motherduck_token={token}")
|
| 58 |
+
except Exception:
|
| 59 |
+
# Fallback: connect to workspace, then USE/CREATE the DB if permitted
|
| 60 |
+
self.client = duckdb.connect(f"md:?motherduck_token={token}")
|
| 61 |
self._ensure_db_context(db_name, allow_create)
|
| 62 |
|
| 63 |
else:
|
|
|
|
| 67 |
def _ensure_db_context(self, db_name: str, allow_create: bool):
|
| 68 |
"""
|
| 69 |
Try to USE the target DB; if it doesn't exist and allow_create=True, create it and USE it.
|
| 70 |
+
Skips reserved workspace names.
|
| 71 |
"""
|
| 72 |
+
if db_name in RESERVED_MD_WORKSPACE_NAMES:
|
| 73 |
+
# No-op for workspace/default
|
| 74 |
+
return
|
| 75 |
+
|
| 76 |
+
# Attempt USE first
|
| 77 |
try:
|
| 78 |
self.client.execute(f"USE {self._quote_ident(db_name)};")
|
| 79 |
+
return
|
| 80 |
except Exception as use_err:
|
| 81 |
if not allow_create:
|
| 82 |
raise RuntimeError(
|
| 83 |
f"Database '{db_name}' not found and ALLOW_CREATE_DB is false. "
|
| 84 |
f"Original error: {use_err}"
|
| 85 |
)
|
| 86 |
+
|
| 87 |
+
# Attempt CREATE then USE
|
| 88 |
+
try:
|
| 89 |
+
# CREATE DATABASE <name>; is supported on MotherDuck for valid names (not 'default')
|
| 90 |
+
self.client.execute(f"CREATE DATABASE {self._quote_ident(db_name)};")
|
| 91 |
+
self.client.execute(f"USE {self._quote_ident(db_name)};")
|
| 92 |
+
except Exception as create_err:
|
| 93 |
+
raise RuntimeError(
|
| 94 |
+
f"Could not create or use database '{db_name}'. "
|
| 95 |
+
f"Original errors: CREATE: {create_err}"
|
| 96 |
+
)
|
| 97 |
|
| 98 |
@staticmethod
|
| 99 |
def _quote_ident(name: str) -> str:
|
| 100 |
"""
|
| 101 |
+
Very light identifier quoting. Replace non [a-zA-Z0-9_] with underscore.
|
| 102 |
"""
|
| 103 |
+
safe = re.sub(r"[^a-zA-Z0-9_]", "_", (name or ""))
|
|
|
|
|
|
|
|
|
|
| 104 |
return safe
|
| 105 |
|
| 106 |
# ----- NL → SQL heuristic (toy example; edit to your schema) -----
|
|
|
|
| 119 |
)
|
| 120 |
|
| 121 |
# If user typed SQL already, run it as-is
|
| 122 |
+
if re.match(r"^\\s*select ", m):
|
| 123 |
return message
|
| 124 |
|
| 125 |
# Fallback
|