Update tools/sql_tool.py
Browse files- tools/sql_tool.py +69 -28
tools/sql_tool.py
CHANGED
|
@@ -16,68 +16,110 @@ class SQLTool:
|
|
| 16 |
if self.backend == "bigquery":
|
| 17 |
from google.cloud import bigquery
|
| 18 |
from google.oauth2 import service_account
|
|
|
|
| 19 |
key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
|
| 20 |
if not key_json:
|
| 21 |
raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
|
| 22 |
|
| 23 |
-
# Accept full JSON string from Space Secret
|
| 24 |
info = json.loads(key_json) if key_json.strip().startswith("{") else {}
|
| 25 |
creds = service_account.Credentials.from_service_account_info(info)
|
| 26 |
self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
|
| 27 |
|
| 28 |
elif self.backend == "motherduck":
|
| 29 |
import duckdb
|
| 30 |
-
|
|
|
|
| 31 |
if not duckdb.__version__.startswith("1.3.2"):
|
| 32 |
raise RuntimeError(
|
| 33 |
f"Incompatible DuckDB version {duckdb.__version__}. "
|
| 34 |
"Pin duckdb==1.3.2 in requirements.txt and redeploy."
|
| 35 |
)
|
| 36 |
|
| 37 |
-
token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
|
| 38 |
-
db_name =
|
|
|
|
| 39 |
if not token:
|
| 40 |
raise RuntimeError("Missing MOTHERDUCK_TOKEN")
|
| 41 |
|
| 42 |
-
#
|
| 43 |
-
#
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
else:
|
| 55 |
-
raise RuntimeError("Unknown SQL backend")
|
| 56 |
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
"""
|
| 59 |
-
|
| 60 |
-
Edit table/column names to your schema.
|
| 61 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
m = message.lower()
|
| 63 |
|
| 64 |
-
#
|
| 65 |
if "avg" in m and " by " in m:
|
| 66 |
return (
|
| 67 |
-
"-- Example template; edit
|
| 68 |
-
"SELECT DATE_TRUNC('month', date_col) AS month
|
| 69 |
-
"AVG(metric) AS avg_metric
|
| 70 |
-
"FROM analytics.table
|
| 71 |
-
"GROUP BY 1
|
| 72 |
"ORDER BY 1;"
|
| 73 |
)
|
| 74 |
|
| 75 |
-
#
|
| 76 |
if re.match(r"^\s*select ", m):
|
| 77 |
return message
|
| 78 |
|
|
|
|
| 79 |
return "SELECT * FROM analytics.table LIMIT 100;"
|
| 80 |
|
|
|
|
| 81 |
def run(self, message: str) -> pd.DataFrame:
|
| 82 |
sql = self._nl_to_sql(message)
|
| 83 |
try:
|
|
@@ -88,5 +130,4 @@ class SQLTool:
|
|
| 88 |
if self.backend == "bigquery":
|
| 89 |
return self.client.query(sql).to_dataframe()
|
| 90 |
else:
|
| 91 |
-
# DuckDB (MotherDuck)
|
| 92 |
return self.client.execute(sql).fetch_df()
|
|
|
|
| 16 |
if self.backend == "bigquery":
|
| 17 |
from google.cloud import bigquery
|
| 18 |
from google.oauth2 import service_account
|
| 19 |
+
|
| 20 |
key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
|
| 21 |
if not key_json:
|
| 22 |
raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
|
| 23 |
|
|
|
|
| 24 |
info = json.loads(key_json) if key_json.strip().startswith("{") else {}
|
| 25 |
creds = service_account.Credentials.from_service_account_info(info)
|
| 26 |
self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
|
| 27 |
|
| 28 |
elif self.backend == "motherduck":
|
| 29 |
import duckdb
|
| 30 |
+
|
| 31 |
+
# MotherDuck extension compatibility: widely supported ABI is DuckDB 1.3.2
|
| 32 |
if not duckdb.__version__.startswith("1.3.2"):
|
| 33 |
raise RuntimeError(
|
| 34 |
f"Incompatible DuckDB version {duckdb.__version__}. "
|
| 35 |
"Pin duckdb==1.3.2 in requirements.txt and redeploy."
|
| 36 |
)
|
| 37 |
|
| 38 |
+
token = (self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN") or "").strip()
|
| 39 |
+
db_name = (self.cfg.motherduck_db or "workspace").strip()
|
| 40 |
+
allow_create = (os.getenv("ALLOW_CREATE_DB", "true").lower() == "true")
|
| 41 |
if not token:
|
| 42 |
raise RuntimeError("Missing MOTHERDUCK_TOKEN")
|
| 43 |
|
| 44 |
+
# Primary path: connect directly to the database
|
| 45 |
+
# Correct formats: "md:" (workspace) or "md:<dbname>" (specific DB)
|
| 46 |
+
try:
|
| 47 |
+
uri = f"md:{db_name}?motherduck_token={token}" if db_name and db_name != "workspace" else f"md:?motherduck_token={token}"
|
| 48 |
+
self.client = duckdb.connect(uri)
|
| 49 |
+
# If we connected to workspace explicitly, set DB context if provided
|
| 50 |
+
if db_name and db_name != "workspace":
|
| 51 |
+
# Ensure we are actually in the right DB context
|
| 52 |
+
self._ensure_db_context(db_name, allow_create)
|
| 53 |
+
except Exception as e:
|
| 54 |
+
# Fallback: connect to workspace, then create/use DB if needed
|
| 55 |
+
self.client = duckdb.connect(f"md:?motherduck_token={token}")
|
| 56 |
+
if not db_name or db_name == "workspace":
|
| 57 |
+
# Using workspace only, caller must fully-qualify schema.table in queries
|
| 58 |
+
pass
|
| 59 |
+
else:
|
| 60 |
+
self._ensure_db_context(db_name, allow_create)
|
| 61 |
+
|
| 62 |
else:
|
| 63 |
+
raise RuntimeError(f"Unknown SQL backend: {self.backend}")
|
| 64 |
|
| 65 |
+
# ----- MotherDuck helpers -----
|
| 66 |
+
def _ensure_db_context(self, db_name: str, allow_create: bool):
|
| 67 |
+
"""
|
| 68 |
+
Try to USE the target DB; if it doesn't exist and allow_create=True, create it and USE it.
|
| 69 |
+
"""
|
| 70 |
+
# DuckDB/MotherDuck: USE <db_name>;
|
| 71 |
+
try:
|
| 72 |
+
self.client.execute(f"USE {self._quote_ident(db_name)};")
|
| 73 |
+
except Exception as use_err:
|
| 74 |
+
if not allow_create:
|
| 75 |
+
raise RuntimeError(
|
| 76 |
+
f"Database '{db_name}' not found and ALLOW_CREATE_DB is false. "
|
| 77 |
+
f"Original error: {use_err}"
|
| 78 |
+
)
|
| 79 |
+
# Attempt to create then USE
|
| 80 |
+
try:
|
| 81 |
+
self.client.execute(f"CREATE DATABASE {self._quote_ident(db_name)};")
|
| 82 |
+
self.client.execute(f"USE {self._quote_ident(db_name)};")
|
| 83 |
+
except Exception as create_err:
|
| 84 |
+
raise RuntimeError(
|
| 85 |
+
f"Could not create or use database '{db_name}'. "
|
| 86 |
+
f"Original errors: USE: {use_err} | CREATE: {create_err}"
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
@staticmethod
|
| 90 |
+
def _quote_ident(name: str) -> str:
|
| 91 |
"""
|
| 92 |
+
Very light identifier quoting. Adjust if you allow special chars.
|
|
|
|
| 93 |
"""
|
| 94 |
+
if not name:
|
| 95 |
+
return name
|
| 96 |
+
# basic guard; you can tighten rules for your org naming conventions
|
| 97 |
+
safe = re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
| 98 |
+
return safe
|
| 99 |
+
|
| 100 |
+
# ----- NL → SQL heuristic (toy example; edit to your schema) -----
|
| 101 |
+
def _nl_to_sql(self, message: str) -> str:
|
| 102 |
m = message.lower()
|
| 103 |
|
| 104 |
+
# Example DuckDB/MotherDuck flavor of DATE_TRUNC
|
| 105 |
if "avg" in m and " by " in m:
|
| 106 |
return (
|
| 107 |
+
"-- Example template; edit to your schema/columns\n"
|
| 108 |
+
"SELECT DATE_TRUNC('month', date_col) AS month,\n"
|
| 109 |
+
" AVG(metric) AS avg_metric\n"
|
| 110 |
+
"FROM analytics.table\n"
|
| 111 |
+
"GROUP BY 1\n"
|
| 112 |
"ORDER BY 1;"
|
| 113 |
)
|
| 114 |
|
| 115 |
+
# If user typed SQL already, run it as-is
|
| 116 |
if re.match(r"^\s*select ", m):
|
| 117 |
return message
|
| 118 |
|
| 119 |
+
# Fallback
|
| 120 |
return "SELECT * FROM analytics.table LIMIT 100;"
|
| 121 |
|
| 122 |
+
# ----- Execute -----
|
| 123 |
def run(self, message: str) -> pd.DataFrame:
|
| 124 |
sql = self._nl_to_sql(message)
|
| 125 |
try:
|
|
|
|
| 130 |
if self.backend == "bigquery":
|
| 131 |
return self.client.query(sql).to_dataframe()
|
| 132 |
else:
|
|
|
|
| 133 |
return self.client.execute(sql).fetch_df()
|