ALM_LLM / tools /sql_tool.py
AshenH's picture
Update tools/sql_tool.py
54614e9 verified
raw
history blame
3.65 kB
# space/tools/sql_tool.py
import os
import re
import json
import pandas as pd
from utils.config import AppConfig
from utils.tracing import Tracer
class SQLTool:
    """Run SQL derived from a user message against BigQuery or MotherDuck.

    The backend is selected from ``cfg.sql_backend`` at construction time; the
    resulting client is stored on ``self.client`` and ``run`` returns query
    results as a pandas DataFrame.
    """

    # A message that (after lower-casing) starts with SELECT is treated as
    # explicit SQL. ``\b`` instead of a literal trailing space also accepts
    # "select\n* ..." or "select* ...", while still rejecting "selection ...".
    _SQL_PASSTHROUGH_RE = re.compile(r"^\s*select\b")

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Create the client for the configured SQL backend.

        Args:
            cfg: application config. Reads ``sql_backend`` plus either
                ``gcp_project`` (BigQuery) or ``motherduck_token`` /
                ``motherduck_db`` (MotherDuck).
            tracer: receives a ``"sql_query"`` event for every query in ``run``.

        Raises:
            RuntimeError: on missing or malformed credentials, an incompatible
                DuckDB version, or an unknown ``sql_backend``.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend  # "bigquery" or "motherduck"
        if self.backend == "bigquery":
            self.client = self._connect_bigquery()
        elif self.backend == "motherduck":
            self.client = self._connect_motherduck()
        else:
            raise RuntimeError(f"Unknown SQL backend: {self.backend!r}")

    @staticmethod
    def _parse_service_account_json(raw: str) -> dict:
        """Return the service-account key as a dict.

        ``raw`` is either the full key JSON (as stored in a Space secret) or a
        filesystem path to the key file.

        Raises:
            RuntimeError: if ``raw`` is neither valid JSON nor an existing file.
        """
        text = raw.strip()
        if text.startswith("{"):
            try:
                return json.loads(text)
            except json.JSONDecodeError as err:
                # Deliberately not echoing the secret's contents in the message.
                raise RuntimeError("GCP_SERVICE_ACCOUNT_JSON is not valid JSON") from err
        if os.path.isfile(text):
            with open(text, encoding="utf-8") as fh:
                return json.load(fh)
        # Passing {} on to google-auth (the previous behavior) only produced a
        # confusing "missing fields" error later; fail here with a clear message.
        raise RuntimeError(
            "GCP_SERVICE_ACCOUNT_JSON must contain the service-account key JSON "
            "or a path to the key file"
        )

    def _connect_bigquery(self):
        """Build an authenticated BigQuery client from GCP_SERVICE_ACCOUNT_JSON."""
        # Imported lazily so the google-cloud packages are only required when
        # this backend is selected.
        from google.cloud import bigquery
        from google.oauth2 import service_account

        key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
        if not key_json:
            raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
        info = self._parse_service_account_json(key_json)
        creds = service_account.Credentials.from_service_account_info(info)
        return bigquery.Client(credentials=creds, project=self.cfg.gcp_project)

    def _connect_motherduck(self):
        """Open a DuckDB connection to the configured MotherDuck database."""
        # Imported lazily so duckdb is only required when this backend is selected.
        import duckdb

        # MotherDuck currently supports DuckDB 1.3.2 broadly across hosts
        if not duckdb.__version__.startswith("1.3.2"):
            raise RuntimeError(
                f"Incompatible DuckDB version {duckdb.__version__}. "
                "Pin duckdb==1.3.2 in requirements.txt and redeploy."
            )
        token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
        db_name = (self.cfg.motherduck_db or "workspace").strip()
        if not token:
            raise RuntimeError("Missing MOTHERDUCK_TOKEN")
        # Connecting to an "md:" URI lets DuckDB load the MotherDuck extension
        # itself; no manual INSTALL/LOAD/ATTACH is needed. Valid forms:
        #   "md:"                                -> workspace (all DBs)
        #   f"md:{db_name}"                      -> a specific DB
        #   f"md:{db_name}?motherduck_token=..." -> with the token in the URI
        # NOTE(review): the default db_name "workspace" yields "md:workspace",
        # i.e. a database literally named "workspace", not the bare "md:"
        # workspace form above — confirm that is intended. If you connect to
        # bare "md:", you may also want to run "USE <database>;" afterwards.
        # NOTE(review): the token is embedded in the URI; make sure the URI is
        # never logged or included in error reports.
        uri = f"md:{db_name}?motherduck_token={token}"
        return duckdb.connect(uri)

    def _nl_to_sql(self, message: str) -> str:
        """Translate a user message into SQL (minimal heuristic placeholder).

        Replace with your own mapping or an LLM prompt, and edit the
        table/column names to match your schema.

        Returns:
            ``message`` unchanged when it already starts with SELECT; the
            monthly-average example template when the message mentions "avg"
            and " by "; otherwise a ``LIMIT 100`` preview query.
        """
        m = message.lower()
        # Explicit SQL must be checked first: a query such as
        # "SELECT AVG(x) FROM t GROUP BY y" contains both "avg" and " by " and
        # would otherwise be replaced by the example template below.
        # NOTE(review): pass-through executes the user's SQL verbatim; depending
        # on the backend, further statements after a ';' may run too. Use
        # read-only credentials or validate the statement if messages can come
        # from untrusted users.
        if self._SQL_PASSTHROUGH_RE.match(m):
            return message
        # Simple example (DuckDB/MotherDuck DATE_TRUNC flavor)
        if "avg" in m and " by " in m:
            return (
                "-- Example template; edit me\n"
                "SELECT DATE_TRUNC('month', date_col) AS month, "
                "AVG(metric) AS avg_metric "
                "FROM analytics.table "
                "GROUP BY 1 "
                "ORDER BY 1;"
            )
        return "SELECT * FROM analytics.table LIMIT 100;"

    def run(self, message: str) -> pd.DataFrame:
        """Translate ``message`` to SQL, execute it, and return the result.

        The SQL is first reported to the tracer as a ``"sql_query"`` event
        together with the backend name.
        """
        sql = self._nl_to_sql(message)
        try:
            self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
        except Exception:
            # Tracing is deliberately best-effort: a tracer failure must never
            # prevent the query from running.
            pass
        if self.backend == "bigquery":
            return self.client.query(sql).to_dataframe()
        # DuckDB (MotherDuck)
        return self.client.execute(sql).fetch_df()