AshenH commited on
Commit
54614e9
·
verified ·
1 Parent(s): 23dc37a

Update tools/sql_tool.py

Browse files
Files changed (1) hide show
  1. tools/sql_tool.py +18 -38
tools/sql_tool.py CHANGED
@@ -2,11 +2,7 @@
2
  import os
3
  import re
4
  import json
5
- import shutil
6
- import glob
7
  import pandas as pd
8
- from typing import Optional
9
-
10
  from utils.config import AppConfig
11
  from utils.tracing import Tracer
12
 
@@ -20,7 +16,6 @@ class SQLTool:
20
  if self.backend == "bigquery":
21
  from google.cloud import bigquery
22
  from google.oauth2 import service_account
23
-
24
  key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
25
  if not key_json:
26
  raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
@@ -32,54 +27,42 @@ class SQLTool:
32
 
33
  elif self.backend == "motherduck":
34
  import duckdb
35
-
36
- # ---- Enforce supported DuckDB version for MotherDuck extension ----
37
  if not duckdb.__version__.startswith("1.3.2"):
38
  raise RuntimeError(
39
  f"Incompatible DuckDB version {duckdb.__version__}. "
40
- "MotherDuck currently supports DuckDB 1.3.2. "
41
  "Pin duckdb==1.3.2 in requirements.txt and redeploy."
42
  )
43
 
44
  token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
45
- db_name = self.cfg.motherduck_db or "default"
46
  if not token:
47
  raise RuntimeError("Missing MOTHERDUCK_TOKEN")
48
 
49
- # ---- Clean stale extension caches compiled for other DuckDB versions ----
50
- try:
51
- ext_root = os.path.expanduser("~/.duckdb/extensions")
52
- for p in glob.glob(os.path.join(ext_root, "*")):
53
- if "1.3.2" not in p: # keep only current version caches
54
- shutil.rmtree(p, ignore_errors=True)
55
- except Exception:
56
- # best-effort cleanup; proceed even if it fails
57
- pass
58
-
59
- # ---- Connect & load MotherDuck extension ----
60
- self.client = duckdb.connect() # in-memory connection; we'll ATTACH MotherDuck
61
- self.client.execute("INSTALL motherduck;")
62
- self.client.execute("LOAD motherduck;")
63
-
64
- # Attach the remote MotherDuck database and use it
65
- self.client.execute(f"SET motherduck_token='{token}';")
66
- self.client.execute(f"ATTACH 'md:/{db_name}' AS md;")
67
- self.client.execute("USE md;") # subsequent queries run against 'md' by default
68
  else:
69
  raise RuntimeError("Unknown SQL backend")
70
 
71
  def _nl_to_sql(self, message: str) -> str:
72
  """
73
  Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
74
- Expect users to include table names. Example:
75
- "avg metric by month from analytics.events"
76
  """
77
  m = message.lower()
78
 
79
- # Very basic template example (edit table/columns to your schema)
80
  if "avg" in m and " by " in m:
81
- # DuckDB uses DATE_TRUNC('month', col); BigQuery uses DATE_TRUNC(col, MONTH).
82
- # This generic SQL should work in DuckDB/MotherDuck; adapt if using BigQuery.
83
  return (
84
  "-- Example template; edit me\n"
85
  "SELECT DATE_TRUNC('month', date_col) AS month, "
@@ -93,7 +76,6 @@ class SQLTool:
93
  if re.match(r"^\s*select ", m):
94
  return message
95
 
96
- # Fallback
97
  return "SELECT * FROM analytics.table LIMIT 100;"
98
 
99
  def run(self, message: str) -> pd.DataFrame:
@@ -104,9 +86,7 @@ class SQLTool:
104
  pass
105
 
106
  if self.backend == "bigquery":
107
- df = self.client.query(sql).to_dataframe()
108
  else:
109
  # DuckDB (MotherDuck)
110
- df = self.client.execute(sql).fetch_df()
111
-
112
- return df
 
2
  import os
3
  import re
4
  import json
 
 
5
  import pandas as pd
 
 
6
  from utils.config import AppConfig
7
  from utils.tracing import Tracer
8
 
 
16
  if self.backend == "bigquery":
17
  from google.cloud import bigquery
18
  from google.oauth2 import service_account
 
19
  key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
20
  if not key_json:
21
  raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
 
27
 
28
  elif self.backend == "motherduck":
29
  import duckdb
30
+ # MotherDuck currently supports DuckDB 1.3.2 broadly across hosts
 
31
  if not duckdb.__version__.startswith("1.3.2"):
32
  raise RuntimeError(
33
  f"Incompatible DuckDB version {duckdb.__version__}. "
 
34
  "Pin duckdb==1.3.2 in requirements.txt and redeploy."
35
  )
36
 
37
  token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
38
+ db_name = (self.cfg.motherduck_db or "workspace").strip()
39
  if not token:
40
  raise RuntimeError("Missing MOTHERDUCK_TOKEN")
41
 
42
+ # Easiest, correct way: connect directly to MotherDuck database.
43
+ # This will auto-download/load the extension; no manual INSTALL/LOAD/ATTACH needed.
44
+ # Valid URIs include:
45
+ # "md:" -> connects to workspace (all DBs)
46
+ # f"md:{db_name}" -> connects to a specific DB
47
+ # f"md:{db_name}?motherduck_token=..." -> with token in URI
48
+ uri = f"md:{db_name}?motherduck_token={token}"
49
+ self.client = duckdb.connect(uri)
50
+
51
+ # Optional: set a default database context (USE) if you connected to 'md:' (workspace)
52
+ # if db_name in ("", "workspace"):
53
+ # self.client.execute("USE your_database;")
 
 
 
 
 
 
 
54
  else:
55
  raise RuntimeError("Unknown SQL backend")
56
 
57
  def _nl_to_sql(self, message: str) -> str:
58
  """
59
  Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
60
+ Edit table/column names to your schema.
 
61
  """
62
  m = message.lower()
63
 
64
+ # Simple example (DuckDB/MotherDuck DATE_TRUNC flavor)
65
  if "avg" in m and " by " in m:
 
 
66
  return (
67
  "-- Example template; edit me\n"
68
  "SELECT DATE_TRUNC('month', date_col) AS month, "
 
76
  if re.match(r"^\s*select ", m):
77
  return message
78
 
 
79
  return "SELECT * FROM analytics.table LIMIT 100;"
80
 
81
  def run(self, message: str) -> pd.DataFrame:
 
86
  pass
87
 
88
  if self.backend == "bigquery":
89
+ return self.client.query(sql).to_dataframe()
90
  else:
91
  # DuckDB (MotherDuck)
92
+ return self.client.execute(sql).fetch_df()