|
|
|
|
|
import os |
|
|
import re |
|
|
import json |
|
|
import logging |
|
|
import pandas as pd |
|
|
from typing import Optional |
|
|
from utils.config import AppConfig |
|
|
from utils.tracing import Tracer |
|
|
|
|
|
logger = logging.getLogger(__name__)


# MotherDuck database names that denote the default workspace rather than a
# concrete database; connecting with one of these skips explicit database
# selection/creation (see _init_motherduck / _ensure_db_context).
RESERVED_MD_WORKSPACE_NAMES = {"", "workspace", "default"}

# Upper bound on accepted SQL text length, in characters (see _validate_sql).
MAX_QUERY_LENGTH = 50000

# Hard cap on rows returned to callers; larger results are truncated (see run).
MAX_RESULT_ROWS = 100000
|
|
|
|
|
|
|
|
class SQLToolError(Exception):
    """Raised when SQL backend setup, validation, or execution fails."""
|
|
|
|
|
|
|
|
class SQLTool: |
|
|
""" |
|
|
SQL execution tool supporting BigQuery and MotherDuck backends. |
|
|
Includes input validation, error handling, and secure query execution. |
|
|
""" |
|
|
|
|
|
def __init__(self, cfg: AppConfig, tracer: Tracer): |
|
|
self.cfg = cfg |
|
|
self.tracer = tracer |
|
|
self.backend = cfg.sql_backend |
|
|
self.client = None |
|
|
|
|
|
logger.info(f"Initializing SQLTool with backend: {self.backend}") |
|
|
|
|
|
try: |
|
|
if self.backend == "bigquery": |
|
|
self._init_bigquery() |
|
|
elif self.backend == "motherduck": |
|
|
self._init_motherduck() |
|
|
else: |
|
|
raise SQLToolError(f"Unknown SQL backend: {self.backend}") |
|
|
|
|
|
logger.info(f"SQLTool initialized successfully with {self.backend}") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Failed to initialize SQLTool: {e}") |
|
|
raise SQLToolError(f"SQL backend initialization failed: {e}") from e |
|
|
|
|
|
def _init_bigquery(self): |
|
|
"""Initialize BigQuery client with service account credentials.""" |
|
|
try: |
|
|
from google.cloud import bigquery |
|
|
from google.oauth2 import service_account |
|
|
|
|
|
key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON") |
|
|
if not key_json: |
|
|
raise SQLToolError( |
|
|
"Missing GCP_SERVICE_ACCOUNT_JSON environment variable. " |
|
|
"Please configure BigQuery credentials." |
|
|
) |
|
|
|
|
|
|
|
|
try: |
|
|
if key_json.strip().startswith("{"): |
|
|
info = json.loads(key_json) |
|
|
else: |
|
|
|
|
|
with open(key_json, 'r') as f: |
|
|
info = json.load(f) |
|
|
except json.JSONDecodeError as e: |
|
|
raise SQLToolError(f"Invalid JSON in GCP_SERVICE_ACCOUNT_JSON: {e}") |
|
|
except FileNotFoundError: |
|
|
raise SQLToolError(f"GCP service account file not found: {key_json}") |
|
|
|
|
|
|
|
|
required_fields = ["type", "project_id", "private_key", "client_email"] |
|
|
missing = [f for f in required_fields if f not in info] |
|
|
if missing: |
|
|
raise SQLToolError( |
|
|
f"GCP service account JSON missing required fields: {missing}" |
|
|
) |
|
|
|
|
|
creds = service_account.Credentials.from_service_account_info(info) |
|
|
project = self.cfg.gcp_project or info.get("project_id") |
|
|
|
|
|
if not project: |
|
|
raise SQLToolError("GCP project ID not specified in config or credentials") |
|
|
|
|
|
self.client = bigquery.Client(credentials=creds, project=project) |
|
|
logger.info(f"BigQuery client initialized for project: {project}") |
|
|
|
|
|
except ImportError as e: |
|
|
raise SQLToolError( |
|
|
"BigQuery dependencies not installed. " |
|
|
"Install with: pip install google-cloud-bigquery" |
|
|
) from e |
|
|
|
|
|
    def _init_motherduck(self):
        """
        Initialize the MotherDuck/DuckDB client with version validation.

        Reads the token from ``cfg.motherduck_token`` or the
        ``MOTHERDUCK_TOKEN`` environment variable (config wins), then
        connects either to the default workspace (for reserved names) or to
        the named database, with a workspace-level fallback that can create
        the database when ``ALLOW_CREATE_DB`` permits.

        Raises:
            SQLToolError: On missing token, failed connection test, or
                missing duckdb dependency.
        """
        try:
            import duckdb

            # Log the driver version up front; MotherDuck compatibility
            # depends on it.
            version = duckdb.__version__
            logger.info(f"DuckDB version: {version}")

            # NOTE(review): only the 1.3.x line is treated as "recommended"
            # here, so newer DuckDB releases will also trigger this warning.
            if not version.startswith("1.3"):
                logger.warning(
                    f"DuckDB {version} detected. Recommended: 1.3.x for MotherDuck compatibility. "
                    "Some features may not work as expected."
                )

            # Token resolution: config value takes precedence over the env var.
            token = (self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN") or "").strip()
            if not token:
                raise SQLToolError(
                    "Missing MOTHERDUCK_TOKEN. "
                    "Get your token from: https://motherduck.com/docs/key-tasks/authenticating-to-motherduck"
                )

            db_name = (self.cfg.motherduck_db or "workspace").strip()
            # Database auto-creation is opt-out via ALLOW_CREATE_DB=false.
            allow_create = os.getenv("ALLOW_CREATE_DB", "true").lower() == "true"

            if db_name in RESERVED_MD_WORKSPACE_NAMES:
                # Reserved names map to the default workspace: connect without
                # naming a specific database.
                connection_string = f"md:?motherduck_token={token}"
                logger.info("Connecting to MotherDuck workspace")
                self.client = duckdb.connect(connection_string)
            else:
                # First attempt: connect directly to the named database.
                try:
                    connection_string = f"md:{db_name}?motherduck_token={token}"
                    logger.info(f"Connecting to MotherDuck database: {db_name}")
                    self.client = duckdb.connect(connection_string)
                except Exception as db_err:
                    logger.warning(f"Direct connection to '{db_name}' failed: {db_err}")

                    # Fallback: connect to the workspace, then USE (and
                    # possibly CREATE) the database explicitly.
                    connection_string = f"md:?motherduck_token={token}"
                    self.client = duckdb.connect(connection_string)
                    self._ensure_db_context(db_name, allow_create)

            # Smoke-test the connection before declaring success.
            try:
                self.client.execute("SELECT 1").fetchone()
                logger.info("MotherDuck connection test successful")
            except Exception as e:
                raise SQLToolError(f"MotherDuck connection test failed: {e}")

        except ImportError as e:
            raise SQLToolError(
                "DuckDB not installed. Install with: pip install duckdb"
            ) from e
|
|
|
|
|
def _ensure_db_context(self, db_name: str, allow_create: bool): |
|
|
""" |
|
|
Ensure database context is set for MotherDuck. |
|
|
Creates database if it doesn't exist and allow_create is True. |
|
|
""" |
|
|
if db_name in RESERVED_MD_WORKSPACE_NAMES: |
|
|
return |
|
|
|
|
|
safe_name = self._quote_ident(db_name) |
|
|
|
|
|
|
|
|
try: |
|
|
self.client.execute(f"USE {safe_name};") |
|
|
logger.info(f"Using existing database: {db_name}") |
|
|
return |
|
|
except Exception as use_err: |
|
|
logger.info(f"Database '{db_name}' not found: {use_err}") |
|
|
|
|
|
if not allow_create: |
|
|
raise SQLToolError( |
|
|
f"Database '{db_name}' does not exist and ALLOW_CREATE_DB is disabled. " |
|
|
f"Either create the database manually or set ALLOW_CREATE_DB=true" |
|
|
) |
|
|
|
|
|
|
|
|
try: |
|
|
logger.info(f"Creating database: {db_name}") |
|
|
self.client.execute(f"CREATE DATABASE IF NOT EXISTS {safe_name};") |
|
|
self.client.execute(f"USE {safe_name};") |
|
|
logger.info(f"Database '{db_name}' created and selected") |
|
|
except Exception as create_err: |
|
|
raise SQLToolError( |
|
|
f"Failed to create database '{db_name}': {create_err}" |
|
|
) from create_err |
|
|
|
|
|
@staticmethod |
|
|
def _quote_ident(name: str) -> str: |
|
|
""" |
|
|
Safely quote SQL identifiers. |
|
|
Replaces non-alphanumeric characters with underscores. |
|
|
""" |
|
|
if not name: |
|
|
return "unnamed" |
|
|
|
|
|
|
|
|
safe = re.sub(r"[^a-zA-Z0-9_]", "_", name) |
|
|
|
|
|
|
|
|
if safe[0].isdigit(): |
|
|
safe = "_" + safe |
|
|
|
|
|
return safe |
|
|
|
|
|
def _validate_sql(self, sql: str) -> tuple[bool, str]: |
|
|
""" |
|
|
Validate SQL query for basic safety. |
|
|
Returns (is_valid, error_message). |
|
|
""" |
|
|
if not sql or not sql.strip(): |
|
|
return False, "Empty SQL query" |
|
|
|
|
|
if len(sql) > MAX_QUERY_LENGTH: |
|
|
return False, f"Query too long (max {MAX_QUERY_LENGTH} characters)" |
|
|
|
|
|
|
|
|
sql_lower = sql.lower() |
|
|
|
|
|
|
|
|
if sql.count(';') > 1: |
|
|
return False, "Multiple SQL statements not allowed" |
|
|
|
|
|
|
|
|
dangerous_patterns = [ |
|
|
(r'\bdrop\s+table\b', "DROP TABLE"), |
|
|
(r'\bdrop\s+database\b', "DROP DATABASE"), |
|
|
(r'\bdelete\s+from\b', "DELETE FROM"), |
|
|
(r'\btruncate\b', "TRUNCATE"), |
|
|
(r'\bexec\s*\(', "EXEC"), |
|
|
(r'\bexecute\s*\(', "EXECUTE"), |
|
|
] |
|
|
|
|
|
for pattern, name in dangerous_patterns: |
|
|
if re.search(pattern, sql_lower): |
|
|
logger.warning(f"Blocked query with {name} pattern") |
|
|
return False, f"Query contains blocked operation: {name}" |
|
|
|
|
|
return True, "" |
|
|
|
|
|
def _nl_to_sql(self, message: str) -> str: |
|
|
""" |
|
|
Convert natural language to SQL query. |
|
|
This is a simple heuristic - replace with proper NL2SQL model for production. |
|
|
""" |
|
|
m = message.lower() |
|
|
|
|
|
|
|
|
if re.match(r'^\s*select\s', m, re.IGNORECASE): |
|
|
return message.strip() |
|
|
|
|
|
|
|
|
if "avg" in m or "average" in m: |
|
|
if "by month" in m or "monthly" in m: |
|
|
return """ |
|
|
SELECT |
|
|
DATE_TRUNC('month', date_col) AS month, |
|
|
AVG(metric_col) AS avg_metric |
|
|
FROM analytics.fact_table |
|
|
GROUP BY 1 |
|
|
ORDER BY 1 DESC |
|
|
LIMIT 100; |
|
|
""" |
|
|
|
|
|
if "top" in m: |
|
|
|
|
|
match = re.search(r'top\s+(\d+)', m) |
|
|
limit = match.group(1) if match else "10" |
|
|
return f""" |
|
|
SELECT * |
|
|
FROM analytics.fact_table |
|
|
ORDER BY metric_col DESC |
|
|
LIMIT {limit}; |
|
|
""" |
|
|
|
|
|
if "count" in m: |
|
|
return """ |
|
|
SELECT |
|
|
category_col, |
|
|
COUNT(*) AS count |
|
|
FROM analytics.fact_table |
|
|
GROUP BY 1 |
|
|
ORDER BY 2 DESC |
|
|
LIMIT 100; |
|
|
""" |
|
|
|
|
|
|
|
|
return """ |
|
|
SELECT * |
|
|
FROM analytics.fact_table |
|
|
LIMIT 100; |
|
|
""" |
|
|
|
|
|
def run(self, message: str) -> pd.DataFrame: |
|
|
""" |
|
|
Execute SQL query from natural language or SQL statement. |
|
|
|
|
|
Args: |
|
|
message: Natural language query or SQL statement |
|
|
|
|
|
Returns: |
|
|
DataFrame with query results |
|
|
|
|
|
Raises: |
|
|
SQLToolError: If query execution fails |
|
|
""" |
|
|
try: |
|
|
|
|
|
sql = self._nl_to_sql(message) |
|
|
logger.info(f"Generated SQL query (first 200 chars): {sql[:200]}") |
|
|
|
|
|
|
|
|
is_valid, error_msg = self._validate_sql(sql) |
|
|
if not is_valid: |
|
|
raise SQLToolError(f"Invalid SQL query: {error_msg}") |
|
|
|
|
|
|
|
|
self.tracer.trace_event("sql_query", { |
|
|
"sql": sql[:1000], |
|
|
"backend": self.backend, |
|
|
"message": message[:500] |
|
|
}) |
|
|
|
|
|
|
|
|
if self.backend == "bigquery": |
|
|
result = self._execute_bigquery(sql) |
|
|
else: |
|
|
result = self._execute_duckdb(sql) |
|
|
|
|
|
|
|
|
if not isinstance(result, pd.DataFrame): |
|
|
raise SQLToolError("Query did not return a DataFrame") |
|
|
|
|
|
|
|
|
if len(result) > MAX_RESULT_ROWS: |
|
|
logger.warning(f"Result truncated from {len(result)} to {MAX_RESULT_ROWS} rows") |
|
|
result = result.head(MAX_RESULT_ROWS) |
|
|
|
|
|
logger.info(f"Query successful: {len(result)} rows, {len(result.columns)} columns") |
|
|
self.tracer.trace_event("sql_success", { |
|
|
"rows": len(result), |
|
|
"columns": len(result.columns) |
|
|
}) |
|
|
|
|
|
return result |
|
|
|
|
|
except SQLToolError: |
|
|
raise |
|
|
except Exception as e: |
|
|
error_msg = f"SQL execution failed: {str(e)}" |
|
|
logger.error(error_msg) |
|
|
self.tracer.trace_event("sql_error", {"error": error_msg}) |
|
|
raise SQLToolError(error_msg) from e |
|
|
|
|
|
def _execute_bigquery(self, sql: str) -> pd.DataFrame: |
|
|
"""Execute query on BigQuery.""" |
|
|
try: |
|
|
query_job = self.client.query(sql) |
|
|
df = query_job.to_dataframe() |
|
|
return df |
|
|
except Exception as e: |
|
|
raise SQLToolError(f"BigQuery execution error: {str(e)}") from e |
|
|
|
|
|
def _execute_duckdb(self, sql: str) -> pd.DataFrame: |
|
|
"""Execute query on DuckDB/MotherDuck.""" |
|
|
try: |
|
|
result = self.client.execute(sql) |
|
|
df = result.fetch_df() |
|
|
return df |
|
|
except Exception as e: |
|
|
raise SQLToolError(f"DuckDB execution error: {str(e)}") from e |
|
|
|
|
|
def test_connection(self) -> bool: |
|
|
"""Test database connection.""" |
|
|
try: |
|
|
test_query = "SELECT 1 AS test" |
|
|
result = self.run(test_query) |
|
|
return len(result) == 1 and result.iloc[0, 0] == 1 |
|
|
except Exception as e: |
|
|
logger.error(f"Connection test failed: {e}") |
|
|
return False |