AshenH committed on
Commit 6c6d38f · verified · 1 Parent(s): da25b2a

Update tools/sql_tool.py

Files changed (1)
  tools/sql_tool.py  +338 −90
tools/sql_tool.py CHANGED
@@ -2,138 +2,386 @@
  import os
  import re
  import json
  import pandas as pd
  from utils.config import AppConfig
  from utils.tracing import Tracer


- RESERVED_MD_WORKSPACE_NAMES = {"", "workspace", "default"}  # treat these as workspace/no-DB context


  class SQLTool:
      def __init__(self, cfg: AppConfig, tracer: Tracer):
          self.cfg = cfg
          self.tracer = tracer
-         self.backend = cfg.sql_backend  # "bigquery" or "motherduck"
-
-         # ---------------- BIGQUERY BACKEND ----------------
-         if self.backend == "bigquery":
              from google.cloud import bigquery
              from google.oauth2 import service_account
-
              key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
              if not key_json:
-                 raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
-
-             info = json.loads(key_json) if key_json.strip().startswith("{") else {}
              creds = service_account.Credentials.from_service_account_info(info)
-             self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
-
-         # ---------------- MOTHERDUCK BACKEND ----------------
-         elif self.backend == "motherduck":
              import duckdb
-
-             # MotherDuck extension compatibility: widely supported ABI is DuckDB 1.3.2
-             if not duckdb.__version__.startswith("1.3.2"):
-                 raise RuntimeError(
-                     f"Incompatible DuckDB version {duckdb.__version__}. "
-                     "Pin duckdb==1.3.2 in requirements.txt and redeploy."
                  )
-
              token = (self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN") or "").strip()
-             db_name = (self.cfg.motherduck_db or "workspace").strip()
-             allow_create = (os.getenv("ALLOW_CREATE_DB", "true").lower() == "true")
              if not token:
-                 raise RuntimeError("Missing MOTHERDUCK_TOKEN")
-
-             # Workspace vs concrete DB handling
              if db_name in RESERVED_MD_WORKSPACE_NAMES:
-                 # Connect to workspace; caller should fully-qualify tables if needed
-                 self.client = duckdb.connect(f"md:?motherduck_token={token}")
-                 # No USE/CREATE in workspace mode
              else:
-                 # Try direct connection to the DB (preferred)
                  try:
-                     self.client = duckdb.connect(f"md:{db_name}?motherduck_token={token}")
-                 except Exception:
-                     # Fallback: connect to workspace, then USE/CREATE the DB if permitted
-                     self.client = duckdb.connect(f"md:?motherduck_token={token}")
                      self._ensure_db_context(db_name, allow_create)
-
-         else:
-             raise RuntimeError(f"Unknown SQL backend: {self.backend}")
-
-     # ----- MotherDuck helpers -----
      def _ensure_db_context(self, db_name: str, allow_create: bool):
          """
-         Try to USE the target DB; if it doesn't exist and allow_create=True, create it and USE it.
-         Skips reserved workspace names.
          """
          if db_name in RESERVED_MD_WORKSPACE_NAMES:
-             # No-op for workspace/default
              return
-
-         # Attempt USE first
          try:
-             self.client.execute(f"USE {self._quote_ident(db_name)};")
              return
          except Exception as use_err:
              if not allow_create:
-                 raise RuntimeError(
-                     f"Database '{db_name}' not found and ALLOW_CREATE_DB is false. "
-                     f"Original error: {use_err}"
                  )
-
-             # Attempt CREATE then USE
              try:
-                 # CREATE DATABASE <name>; is supported on MotherDuck for valid names (not 'default')
-                 self.client.execute(f"CREATE DATABASE {self._quote_ident(db_name)};")
-                 self.client.execute(f"USE {self._quote_ident(db_name)};")
             except Exception as create_err:
-                 raise RuntimeError(
-                     f"Could not create or use database '{db_name}'. "
-                     f"Original errors: CREATE: {create_err}"
-                 )
-
      @staticmethod
      def _quote_ident(name: str) -> str:
          """
-         Very light identifier quoting. Replace non [a-zA-Z0-9_] with underscore.
          """
-         safe = re.sub(r"[^a-zA-Z0-9_]", "_", (name or ""))
          return safe
-
-     # ----- NL → SQL heuristic (toy example; edit to your schema) -----
      def _nl_to_sql(self, message: str) -> str:
          m = message.lower()
-
-         # Example DuckDB/MotherDuck flavor of DATE_TRUNC
-         if "avg" in m and " by " in m:
-             return (
-                 "-- Example template; edit to your schema/columns\n"
-                 "SELECT DATE_TRUNC('month', date_col) AS month,\n"
-                 " AVG(metric) AS avg_metric\n"
-                 "FROM analytics.table\n"
-                 "GROUP BY 1\n"
-                 "ORDER BY 1;"
-             )
-
-         # If user typed SQL already, run it as-is
-         if re.match(r"^\\s*select ", m):
-             return message
-
-         # Fallback
-         return "SELECT * FROM analytics.table LIMIT 100;"
-
-     # ----- Execute -----
      def run(self, message: str) -> pd.DataFrame:
-         sql = self._nl_to_sql(message)
          try:
-             self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
-         except Exception:
-             pass
-
-         if self.backend == "bigquery":
-             return self.client.query(sql).to_dataframe()
-         else:
-             return self.client.execute(sql).fetch_df()

  import os
  import re
  import json
+ import logging
  import pandas as pd
+ from typing import Optional
  from utils.config import AppConfig
  from utils.tracing import Tracer

+ logger = logging.getLogger(__name__)

+ RESERVED_MD_WORKSPACE_NAMES = {"", "workspace", "default"}
+ MAX_QUERY_LENGTH = 50000
+ MAX_RESULT_ROWS = 100000
+
+
+ class SQLToolError(Exception):
+     """Custom exception for SQL tool errors."""
+     pass


  class SQLTool:
+     """
+     SQL execution tool supporting BigQuery and MotherDuck backends.
+     Includes input validation, error handling, and secure query execution.
+     """
+
      def __init__(self, cfg: AppConfig, tracer: Tracer):
          self.cfg = cfg
          self.tracer = tracer
+         self.backend = cfg.sql_backend
+         self.client = None
+
+         logger.info(f"Initializing SQLTool with backend: {self.backend}")
+
+         try:
+             if self.backend == "bigquery":
+                 self._init_bigquery()
+             elif self.backend == "motherduck":
+                 self._init_motherduck()
+             else:
+                 raise SQLToolError(f"Unknown SQL backend: {self.backend}")
+
+             logger.info(f"SQLTool initialized successfully with {self.backend}")
+
+         except Exception as e:
+             logger.error(f"Failed to initialize SQLTool: {e}")
+             raise SQLToolError(f"SQL backend initialization failed: {e}") from e
+
+     def _init_bigquery(self):
+         """Initialize BigQuery client with service account credentials."""
+         try:
              from google.cloud import bigquery
              from google.oauth2 import service_account
+
              key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
              if not key_json:
+                 raise SQLToolError(
+                     "Missing GCP_SERVICE_ACCOUNT_JSON environment variable. "
+                     "Please configure BigQuery credentials."
+                 )
+
+             # Parse credentials
+             try:
+                 if key_json.strip().startswith("{"):
+                     info = json.loads(key_json)
+                 else:
+                     # Assume it's a file path
+                     with open(key_json, 'r') as f:
+                         info = json.load(f)
+             except json.JSONDecodeError as e:
+                 raise SQLToolError(f"Invalid JSON in GCP_SERVICE_ACCOUNT_JSON: {e}")
+             except FileNotFoundError:
+                 raise SQLToolError(f"GCP service account file not found: {key_json}")
+
+             # Validate required fields
+             required_fields = ["type", "project_id", "private_key", "client_email"]
+             missing = [f for f in required_fields if f not in info]
+             if missing:
+                 raise SQLToolError(
+                     f"GCP service account JSON missing required fields: {missing}"
+                 )
+
              creds = service_account.Credentials.from_service_account_info(info)
+             project = self.cfg.gcp_project or info.get("project_id")
+
+             if not project:
+                 raise SQLToolError("GCP project ID not specified in config or credentials")
+
+             self.client = bigquery.Client(credentials=creds, project=project)
+             logger.info(f"BigQuery client initialized for project: {project}")
+
+         except ImportError as e:
+             raise SQLToolError(
+                 "BigQuery dependencies not installed. "
+                 "Install with: pip install google-cloud-bigquery"
+             ) from e
+
+     def _init_motherduck(self):
+         """Initialize MotherDuck/DuckDB client with version validation."""
+         try:
              import duckdb
+
+             # Version compatibility check - be more flexible
+             version = duckdb.__version__
+             logger.info(f"DuckDB version: {version}")
+
+             # Warn if not on recommended version, but don't fail
+             if not version.startswith("1.3"):
+                 logger.warning(
+                     f"DuckDB {version} detected. Recommended: 1.3.x for MotherDuck compatibility. "
+                     "Some features may not work as expected."
                  )
+
+             # Get configuration
              token = (self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN") or "").strip()
              if not token:
+                 raise SQLToolError(
+                     "Missing MOTHERDUCK_TOKEN. "
+                     "Get your token from: https://motherduck.com/docs/key-tasks/authenticating-to-motherduck"
+                 )
+
+             db_name = (self.cfg.motherduck_db or "workspace").strip()
+             allow_create = os.getenv("ALLOW_CREATE_DB", "true").lower() == "true"
+
+             # Connect based on database name
              if db_name in RESERVED_MD_WORKSPACE_NAMES:
+                 # Workspace mode - no specific database context
+                 connection_string = f"md:?motherduck_token={token}"
+                 logger.info("Connecting to MotherDuck workspace")
+                 self.client = duckdb.connect(connection_string)
              else:
+                 # Try connecting to specific database
                  try:
+                     connection_string = f"md:{db_name}?motherduck_token={token}"
+                     logger.info(f"Connecting to MotherDuck database: {db_name}")
+                     self.client = duckdb.connect(connection_string)
+                 except Exception as db_err:
+                     logger.warning(f"Direct connection to '{db_name}' failed: {db_err}")
+
+                     # Fallback: connect to workspace and setup database
+                     connection_string = f"md:?motherduck_token={token}"
+                     self.client = duckdb.connect(connection_string)
                      self._ensure_db_context(db_name, allow_create)
+
+             # Test connection
+             try:
+                 self.client.execute("SELECT 1").fetchone()
+                 logger.info("MotherDuck connection test successful")
+             except Exception as e:
+                 raise SQLToolError(f"MotherDuck connection test failed: {e}")
+
+         except ImportError as e:
+             raise SQLToolError(
+                 "DuckDB not installed. Install with: pip install duckdb"
+             ) from e
+
      def _ensure_db_context(self, db_name: str, allow_create: bool):
          """
+         Ensure database context is set for MotherDuck.
+         Creates database if it doesn't exist and allow_create is True.
          """
          if db_name in RESERVED_MD_WORKSPACE_NAMES:
              return
+
+         safe_name = self._quote_ident(db_name)
+
+         # Try to USE the database first
          try:
+             self.client.execute(f"USE {safe_name};")
+             logger.info(f"Using existing database: {db_name}")
              return
          except Exception as use_err:
+             logger.info(f"Database '{db_name}' not found: {use_err}")
+
              if not allow_create:
+                 raise SQLToolError(
+                     f"Database '{db_name}' does not exist and ALLOW_CREATE_DB is disabled. "
+                     f"Either create the database manually or set ALLOW_CREATE_DB=true"
                  )
+
+             # Attempt to create and use the database
              try:
+                 logger.info(f"Creating database: {db_name}")
+                 self.client.execute(f"CREATE DATABASE IF NOT EXISTS {safe_name};")
+                 self.client.execute(f"USE {safe_name};")
+                 logger.info(f"Database '{db_name}' created and selected")
              except Exception as create_err:
+                 raise SQLToolError(
+                     f"Failed to create database '{db_name}': {create_err}"
+                 ) from create_err
+
      @staticmethod
      def _quote_ident(name: str) -> str:
          """
+         Safely quote SQL identifiers.
+         Replaces non-alphanumeric characters with underscores.
          """
+         if not name:
+             return "unnamed"
+
+         # Remove dangerous characters
+         safe = re.sub(r"[^a-zA-Z0-9_]", "_", name)
+
+         # Ensure it doesn't start with a number
+         if safe[0].isdigit():
+             safe = "_" + safe
+
          return safe
+
+     def _validate_sql(self, sql: str) -> tuple[bool, str]:
+         """
+         Validate SQL query for basic safety.
+         Returns (is_valid, error_message).
+         """
+         if not sql or not sql.strip():
+             return False, "Empty SQL query"
+
+         if len(sql) > MAX_QUERY_LENGTH:
+             return False, f"Query too long (max {MAX_QUERY_LENGTH} characters)"
+
+         # Dangerous patterns check
+         sql_lower = sql.lower()
+
+         # Block multiple statements (simple check)
+         if sql.count(';') > 1:
+             return False, "Multiple SQL statements not allowed"
+
+         # Block dangerous keywords in non-SELECT queries
+         dangerous_patterns = [
+             (r'\bdrop\s+table\b', "DROP TABLE"),
+             (r'\bdrop\s+database\b', "DROP DATABASE"),
+             (r'\bdelete\s+from\b', "DELETE FROM"),
+             (r'\btruncate\b', "TRUNCATE"),
+             (r'\bexec\s*\(', "EXEC"),
+             (r'\bexecute\s*\(', "EXECUTE"),
+         ]
+
+         for pattern, name in dangerous_patterns:
+             if re.search(pattern, sql_lower):
+                 logger.warning(f"Blocked query with {name} pattern")
+                 return False, f"Query contains blocked operation: {name}"
+
+         return True, ""
+
      def _nl_to_sql(self, message: str) -> str:
+         """
+         Convert natural language to SQL query.
+         This is a simple heuristic - replace with proper NL2SQL model for production.
+         """
          m = message.lower()
+
+         # If it's already SQL, return as-is (after validation)
+         if re.match(r'^\s*select\s', m, re.IGNORECASE):
+             return message.strip()
+
+         # Template-based generation (customize for your schema)
+         if "avg" in m or "average" in m:
+             if "by month" in m or "monthly" in m:
+                 return """
+                     SELECT
+                         DATE_TRUNC('month', date_col) AS month,
+                         AVG(metric_col) AS avg_metric
+                     FROM analytics.fact_table
+                     GROUP BY 1
+                     ORDER BY 1 DESC
+                     LIMIT 100;
+                 """
+
+         if "top" in m:
+             # Extract number if present
+             match = re.search(r'top\s+(\d+)', m)
+             limit = match.group(1) if match else "10"
+             return f"""
+                 SELECT *
+                 FROM analytics.fact_table
+                 ORDER BY metric_col DESC
+                 LIMIT {limit};
+             """
+
+         if "count" in m:
+             return """
+                 SELECT
+                     category_col,
+                     COUNT(*) AS count
+                 FROM analytics.fact_table
+                 GROUP BY 1
+                 ORDER BY 2 DESC
+                 LIMIT 100;
+             """
+
+         # Default fallback
+         return """
+             SELECT *
+             FROM analytics.fact_table
+             LIMIT 100;
+         """
+
      def run(self, message: str) -> pd.DataFrame:
+         """
+         Execute SQL query from natural language or SQL statement.
+
+         Args:
+             message: Natural language query or SQL statement
+
+         Returns:
+             DataFrame with query results
+
+         Raises:
+             SQLToolError: If query execution fails
+         """
          try:
+             # Convert to SQL
+             sql = self._nl_to_sql(message)
+             logger.info(f"Generated SQL query (first 200 chars): {sql[:200]}")
+
+             # Validate SQL
+             is_valid, error_msg = self._validate_sql(sql)
+             if not is_valid:
+                 raise SQLToolError(f"Invalid SQL query: {error_msg}")
+
+             # Log query attempt
+             self.tracer.trace_event("sql_query", {
+                 "sql": sql[:1000],  # Limit logged SQL length
+                 "backend": self.backend,
+                 "message": message[:500]
+             })
+
+             # Execute based on backend
+             if self.backend == "bigquery":
+                 result = self._execute_bigquery(sql)
+             else:  # motherduck
+                 result = self._execute_duckdb(sql)
+
+             # Validate result
+             if not isinstance(result, pd.DataFrame):
+                 raise SQLToolError("Query did not return a DataFrame")
+
+             # Check result size
+             if len(result) > MAX_RESULT_ROWS:
+                 logger.warning(f"Result truncated from {len(result)} to {MAX_RESULT_ROWS} rows")
+                 result = result.head(MAX_RESULT_ROWS)
+
+             logger.info(f"Query successful: {len(result)} rows, {len(result.columns)} columns")
+             self.tracer.trace_event("sql_success", {
+                 "rows": len(result),
+                 "columns": len(result.columns)
+             })
+
+             return result
+
+         except SQLToolError:
+             raise
+         except Exception as e:
+             error_msg = f"SQL execution failed: {str(e)}"
+             logger.error(error_msg)
+             self.tracer.trace_event("sql_error", {"error": error_msg})
+             raise SQLToolError(error_msg) from e
+
+     def _execute_bigquery(self, sql: str) -> pd.DataFrame:
+         """Execute query on BigQuery."""
+         try:
+             query_job = self.client.query(sql)
+             df = query_job.to_dataframe()
+             return df
+         except Exception as e:
+             raise SQLToolError(f"BigQuery execution error: {str(e)}") from e
+
+     def _execute_duckdb(self, sql: str) -> pd.DataFrame:
+         """Execute query on DuckDB/MotherDuck."""
+         try:
+             result = self.client.execute(sql)
+             df = result.fetch_df()
+             return df
+         except Exception as e:
+             raise SQLToolError(f"DuckDB execution error: {str(e)}") from e
+
+     def test_connection(self) -> bool:
+         """Test database connection."""
+         try:
+             test_query = "SELECT 1 AS test"
+             result = self.run(test_query)
+             return len(result) == 1 and result.iloc[0, 0] == 1
+         except Exception as e:
+             logger.error(f"Connection test failed: {e}")
+             return False
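
For reference, a minimal usage sketch of the updated class. It assumes AppConfig can be built with keyword fields matching the attributes SQLTool reads (sql_backend, motherduck_token, motherduck_db, gcp_project) and that Tracer() takes no arguments; the actual constructors in utils/ may differ.

import os

from utils.config import AppConfig
from utils.tracing import Tracer
from tools.sql_tool import SQLTool, SQLToolError

# Hypothetical construction: keyword names mirror the attributes SQLTool reads;
# the real AppConfig/Tracer signatures may differ.
cfg = AppConfig(
    sql_backend="motherduck",                        # or "bigquery"
    motherduck_token=os.environ["MOTHERDUCK_TOKEN"],
    motherduck_db="analytics",
    gcp_project=None,
)
tracer = Tracer()

try:
    tool = SQLTool(cfg, tracer)      # raises SQLToolError if the backend cannot be initialized
    if tool.test_connection():
        # Plain SELECT statements pass through _nl_to_sql unchanged;
        # natural-language prompts fall back to the built-in templates.
        df = tool.run("SELECT 42 AS answer")
        print(df)
except SQLToolError as err:
    print(f"SQL tool failed: {err}")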