AshenH commited on
Commit
85b8a4e
·
verified ·
1 Parent(s): 2336094

Update tools/sql_tool.py

Browse files
Files changed (1) hide show
  1. tools/sql_tool.py +37 -31
tools/sql_tool.py CHANGED
@@ -7,12 +7,16 @@ from utils.config import AppConfig
7
  from utils.tracing import Tracer
8
 
9
 
 
 
 
10
  class SQLTool:
11
  def __init__(self, cfg: AppConfig, tracer: Tracer):
12
  self.cfg = cfg
13
  self.tracer = tracer
14
  self.backend = cfg.sql_backend # "bigquery" or "motherduck"
15
 
 
16
  if self.backend == "bigquery":
17
  from google.cloud import bigquery
18
  from google.oauth2 import service_account
@@ -25,6 +29,7 @@ class SQLTool:
25
  creds = service_account.Credentials.from_service_account_info(info)
26
  self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
27
 
 
28
  elif self.backend == "motherduck":
29
  import duckdb
30
 
@@ -41,22 +46,18 @@ class SQLTool:
41
  if not token:
42
  raise RuntimeError("Missing MOTHERDUCK_TOKEN")
43
 
44
- # Primary path: connect directly to the database
45
- # Correct formats: "md:" (workspace) or "md:<dbname>" (specific DB)
46
- try:
47
- uri = f"md:{db_name}?motherduck_token={token}" if db_name and db_name != "workspace" else f"md:?motherduck_token={token}"
48
- self.client = duckdb.connect(uri)
49
- # If we connected to workspace explicitly, set DB context if provided
50
- if db_name and db_name != "workspace":
51
- # Ensure we are actually in the right DB context
52
- self._ensure_db_context(db_name, allow_create)
53
- except Exception as e:
54
- # Fallback: connect to workspace, then create/use DB if needed
55
  self.client = duckdb.connect(f"md:?motherduck_token={token}")
56
- if not db_name or db_name == "workspace":
57
- # Using workspace only, caller must fully-qualify schema.table in queries
58
- pass
59
- else:
 
 
 
 
60
  self._ensure_db_context(db_name, allow_create)
61
 
62
  else:
@@ -66,35 +67,40 @@ class SQLTool:
66
  def _ensure_db_context(self, db_name: str, allow_create: bool):
67
  """
68
  Try to USE the target DB; if it doesn't exist and allow_create=True, create it and USE it.
 
69
  """
70
- # DuckDB/MotherDuck: USE <db_name>;
 
 
 
 
71
  try:
72
  self.client.execute(f"USE {self._quote_ident(db_name)};")
 
73
  except Exception as use_err:
74
  if not allow_create:
75
  raise RuntimeError(
76
  f"Database '{db_name}' not found and ALLOW_CREATE_DB is false. "
77
  f"Original error: {use_err}"
78
  )
79
- # Attempt to create then USE
80
- try:
81
- self.client.execute(f"CREATE DATABASE {self._quote_ident(db_name)};")
82
- self.client.execute(f"USE {self._quote_ident(db_name)};")
83
- except Exception as create_err:
84
- raise RuntimeError(
85
- f"Could not create or use database '{db_name}'. "
86
- f"Original errors: USE: {use_err} | CREATE: {create_err}"
87
- )
 
 
88
 
89
  @staticmethod
90
  def _quote_ident(name: str) -> str:
91
  """
92
- Very light identifier quoting. Adjust if you allow special chars.
93
  """
94
- if not name:
95
- return name
96
- # basic guard; you can tighten rules for your org naming conventions
97
- safe = re.sub(r"[^a-zA-Z0-9_]", "_", name)
98
  return safe
99
 
100
  # ----- NL → SQL heuristic (toy example; edit to your schema) -----
@@ -113,7 +119,7 @@ class SQLTool:
113
  )
114
 
115
  # If user typed SQL already, run it as-is
116
- if re.match(r"^\s*select ", m):
117
  return message
118
 
119
  # Fallback
 
7
  from utils.tracing import Tracer
8
 
9
 
10
+ RESERVED_MD_WORKSPACE_NAMES = {"", "workspace", "default"} # treat these as workspace/no-DB context
11
+
12
+
13
  class SQLTool:
14
  def __init__(self, cfg: AppConfig, tracer: Tracer):
15
  self.cfg = cfg
16
  self.tracer = tracer
17
  self.backend = cfg.sql_backend # "bigquery" or "motherduck"
18
 
19
+ # ---------------- BIGQUERY BACKEND ----------------
20
  if self.backend == "bigquery":
21
  from google.cloud import bigquery
22
  from google.oauth2 import service_account
 
29
  creds = service_account.Credentials.from_service_account_info(info)
30
  self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
31
 
32
+ # ---------------- MOTHERDUCK BACKEND ----------------
33
  elif self.backend == "motherduck":
34
  import duckdb
35
 
 
46
  if not token:
47
  raise RuntimeError("Missing MOTHERDUCK_TOKEN")
48
 
49
+ # Workspace vs concrete DB handling
50
+ if db_name in RESERVED_MD_WORKSPACE_NAMES:
51
+ # Connect to workspace; caller should fully-qualify tables if needed
 
 
 
 
 
 
 
 
52
  self.client = duckdb.connect(f"md:?motherduck_token={token}")
53
+ # No USE/CREATE in workspace mode
54
+ else:
55
+ # Try direct connection to the DB (preferred)
56
+ try:
57
+ self.client = duckdb.connect(f"md:{db_name}?motherduck_token={token}")
58
+ except Exception:
59
+ # Fallback: connect to workspace, then USE/CREATE the DB if permitted
60
+ self.client = duckdb.connect(f"md:?motherduck_token={token}")
61
  self._ensure_db_context(db_name, allow_create)
62
 
63
  else:
 
67
  def _ensure_db_context(self, db_name: str, allow_create: bool):
68
  """
69
  Try to USE the target DB; if it doesn't exist and allow_create=True, create it and USE it.
70
+ Skips reserved workspace names.
71
  """
72
+ if db_name in RESERVED_MD_WORKSPACE_NAMES:
73
+ # No-op for workspace/default
74
+ return
75
+
76
+ # Attempt USE first
77
  try:
78
  self.client.execute(f"USE {self._quote_ident(db_name)};")
79
+ return
80
  except Exception as use_err:
81
  if not allow_create:
82
  raise RuntimeError(
83
  f"Database '{db_name}' not found and ALLOW_CREATE_DB is false. "
84
  f"Original error: {use_err}"
85
  )
86
+
87
+ # Attempt CREATE then USE
88
+ try:
89
+ # CREATE DATABASE <name>; is supported on MotherDuck for valid names (not 'default')
90
+ self.client.execute(f"CREATE DATABASE {self._quote_ident(db_name)};")
91
+ self.client.execute(f"USE {self._quote_ident(db_name)};")
92
+ except Exception as create_err:
93
+ raise RuntimeError(
94
+ f"Could not create or use database '{db_name}'. "
95
+ f"Original errors: CREATE: {create_err}"
96
+ )
97
 
98
  @staticmethod
99
  def _quote_ident(name: str) -> str:
100
  """
101
+ Very light identifier quoting. Replace non [a-zA-Z0-9_] with underscore.
102
  """
103
+ safe = re.sub(r"[^a-zA-Z0-9_]", "_", (name or ""))
 
 
 
104
  return safe
105
 
106
  # ----- NL → SQL heuristic (toy example; edit to your schema) -----
 
119
  )
120
 
121
  # If user typed SQL already, run it as-is
122
+ if re.match(r"^\\s*select ", m):
123
  return message
124
 
125
  # Fallback