AshenH committed on
Commit
2336094
·
verified ·
1 Parent(s): 52a979b

Update tools/sql_tool.py

Browse files
Files changed (1) hide show
  1. tools/sql_tool.py +69 -28
tools/sql_tool.py CHANGED
@@ -16,68 +16,110 @@ class SQLTool:
16
  if self.backend == "bigquery":
17
  from google.cloud import bigquery
18
  from google.oauth2 import service_account
 
19
  key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
20
  if not key_json:
21
  raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
22
 
23
- # Accept full JSON string from Space Secret
24
  info = json.loads(key_json) if key_json.strip().startswith("{") else {}
25
  creds = service_account.Credentials.from_service_account_info(info)
26
  self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
27
 
28
  elif self.backend == "motherduck":
29
  import duckdb
30
- # MotherDuck currently supports DuckDB 1.3.2 broadly across hosts
 
31
  if not duckdb.__version__.startswith("1.3.2"):
32
  raise RuntimeError(
33
  f"Incompatible DuckDB version {duckdb.__version__}. "
34
  "Pin duckdb==1.3.2 in requirements.txt and redeploy."
35
  )
36
 
37
- token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
38
- db_name = my_db
 
39
  if not token:
40
  raise RuntimeError("Missing MOTHERDUCK_TOKEN")
41
 
42
- # Easiest, correct way: connect directly to MotherDuck database.
43
- # This will auto-download/load the extension; no manual INSTALL/LOAD/ATTACH needed.
44
- # Valid URIs include:
45
- # "md:" -> connects to workspace (all DBs)
46
- # f"md:{db_name}" -> connects to a specific DB
47
- # f"md:{db_name}?motherduck_token=..." -> with token in URI
48
- uri = f"md:{db_name}?motherduck_token={token}"
49
- self.client = duckdb.connect(uri)
50
-
51
- # Optional: set a default database context (USE) if you connected to 'md:' (workspace)
52
- # if db_name in ("", "workspace"):
53
- # self.client.execute("USE your_database;")
 
 
 
 
 
 
54
  else:
55
- raise RuntimeError("Unknown SQL backend")
56
 
57
- def _nl_to_sql(self, message: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  """
59
- Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
60
- Edit table/column names to your schema.
61
  """
 
 
 
 
 
 
 
 
62
  m = message.lower()
63
 
64
- # Simple example (DuckDB/MotherDuck DATE_TRUNC flavor)
65
  if "avg" in m and " by " in m:
66
  return (
67
- "-- Example template; edit me\n"
68
- "SELECT DATE_TRUNC('month', date_col) AS month, "
69
- "AVG(metric) AS avg_metric "
70
- "FROM analytics.table "
71
- "GROUP BY 1 "
72
  "ORDER BY 1;"
73
  )
74
 
75
- # Pass-through if the user typed SQL explicitly
76
  if re.match(r"^\s*select ", m):
77
  return message
78
 
 
79
  return "SELECT * FROM analytics.table LIMIT 100;"
80
 
 
81
  def run(self, message: str) -> pd.DataFrame:
82
  sql = self._nl_to_sql(message)
83
  try:
@@ -88,5 +130,4 @@ class SQLTool:
88
  if self.backend == "bigquery":
89
  return self.client.query(sql).to_dataframe()
90
  else:
91
- # DuckDB (MotherDuck)
92
  return self.client.execute(sql).fetch_df()
 
16
  if self.backend == "bigquery":
17
  from google.cloud import bigquery
18
  from google.oauth2 import service_account
19
+
20
  key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
21
  if not key_json:
22
  raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
23
 
 
24
  info = json.loads(key_json) if key_json.strip().startswith("{") else {}
25
  creds = service_account.Credentials.from_service_account_info(info)
26
  self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
27
 
28
  elif self.backend == "motherduck":
29
  import duckdb
30
+
31
+ # MotherDuck extension compatibility: widely supported ABI is DuckDB 1.3.2
32
  if not duckdb.__version__.startswith("1.3.2"):
33
  raise RuntimeError(
34
  f"Incompatible DuckDB version {duckdb.__version__}. "
35
  "Pin duckdb==1.3.2 in requirements.txt and redeploy."
36
  )
37
 
38
+ token = (self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN") or "").strip()
39
+ db_name = (self.cfg.motherduck_db or "workspace").strip()
40
+ allow_create = (os.getenv("ALLOW_CREATE_DB", "true").lower() == "true")
41
  if not token:
42
  raise RuntimeError("Missing MOTHERDUCK_TOKEN")
43
 
44
+ # Primary path: connect directly to the database
45
+ # Correct formats: "md:" (workspace) or "md:<dbname>" (specific DB)
46
+ try:
47
+ uri = f"md:{db_name}?motherduck_token={token}" if db_name and db_name != "workspace" else f"md:?motherduck_token={token}"
48
+ self.client = duckdb.connect(uri)
49
+ # If a specific database (not the workspace) was requested, set the DB context
50
+ if db_name and db_name != "workspace":
51
+ # Ensure we are actually in the right DB context
52
+ self._ensure_db_context(db_name, allow_create)
53
+ except Exception as e:
54
+ # Fallback: connect to workspace, then create/use DB if needed
55
+ self.client = duckdb.connect(f"md:?motherduck_token={token}")
56
+ if not db_name or db_name == "workspace":
57
+ # Using workspace only, caller must fully-qualify schema.table in queries
58
+ pass
59
+ else:
60
+ self._ensure_db_context(db_name, allow_create)
61
+
62
  else:
63
+ raise RuntimeError(f"Unknown SQL backend: {self.backend}")
64
 
65
+ # ----- MotherDuck helpers -----
66
def _ensure_db_context(self, db_name: str, allow_create: bool) -> None:
    """
    Make ``db_name`` the active database on ``self.client``, creating it if allowed.

    Runs ``USE <db_name>``. If that statement fails and ``allow_create`` is
    true, runs ``CREATE DATABASE <db_name>`` followed by a second ``USE``.

    Args:
        db_name: Target database name. It is passed through
            ``self._quote_ident`` before being placed in the SQL text.
        allow_create: Whether a failed ``USE`` may be followed by an attempt
            to create the database.

    Raises:
        RuntimeError: If ``USE`` fails and creation is not allowed, or if the
            create-then-use attempt fails. The underlying exception is
            attached as ``__cause__`` so the driver traceback is preserved.
    """
    # Quote once so USE and CREATE are guaranteed to name the same database.
    ident = self._quote_ident(db_name)
    use_sql = f"USE {ident};"

    try:
        self.client.execute(use_sql)
        return
    except Exception as err:
        # Keep our own reference: the name bound by ``except ... as`` is
        # cleared when the handler exits.
        use_err = err

    # NOTE(review): every failure of USE is treated as "database not found";
    # the code does not distinguish a missing database from, e.g., a
    # connectivity or permission error -- confirm this is intended.
    if not allow_create:
        raise RuntimeError(
            f"Database '{db_name}' not found and ALLOW_CREATE_DB is false. "
            f"Original error: {use_err}"
        ) from use_err

    # Attempt to create the database, then switch to it.
    try:
        self.client.execute(f"CREATE DATABASE {ident};")
        self.client.execute(use_sql)
    except Exception as create_err:
        raise RuntimeError(
            f"Could not create or use database '{db_name}'. "
            f"Original errors: USE: {use_err} | CREATE: {create_err}"
        ) from create_err
88
+
89
+ @staticmethod
90
+ def _quote_ident(name: str) -> str:
91
  """
92
+ Very light identifier quoting. Adjust if you allow special chars.
 
93
  """
94
+ if not name:
95
+ return name
96
+ # basic guard; you can tighten rules for your org naming conventions
97
+ safe = re.sub(r"[^a-zA-Z0-9_]", "_", name)
98
+ return safe
99
+
100
+ # ----- NL → SQL heuristic (toy example; edit to your schema) -----
101
+ def _nl_to_sql(self, message: str) -> str:
102
  m = message.lower()
103
 
104
+ # Example DuckDB/MotherDuck flavor of DATE_TRUNC
105
  if "avg" in m and " by " in m:
106
  return (
107
+ "-- Example template; edit to your schema/columns\n"
108
+ "SELECT DATE_TRUNC('month', date_col) AS month,\n"
109
+ " AVG(metric) AS avg_metric\n"
110
+ "FROM analytics.table\n"
111
+ "GROUP BY 1\n"
112
  "ORDER BY 1;"
113
  )
114
 
115
+ # If user typed SQL already, run it as-is
116
  if re.match(r"^\s*select ", m):
117
  return message
118
 
119
+ # Fallback
120
  return "SELECT * FROM analytics.table LIMIT 100;"
121
 
122
+ # ----- Execute -----
123
  def run(self, message: str) -> pd.DataFrame:
124
  sql = self._nl_to_sql(message)
125
  try:
 
130
  if self.backend == "bigquery":
131
  return self.client.query(sql).to_dataframe()
132
  else:
 
133
  return self.client.execute(sql).fetch_df()