ALM_LLM / tools /sql_tool.py
AshenH's picture
Update tools/sql_tool.py
6c6d38f verified
raw
history blame
14.2 kB
# space/tools/sql_tool.py
import os
import re
import json
import logging
import pandas as pd
from typing import Optional
from utils.config import AppConfig
from utils.tracing import Tracer
logger = logging.getLogger(__name__)
RESERVED_MD_WORKSPACE_NAMES = {"", "workspace", "default"}
MAX_QUERY_LENGTH = 50000
MAX_RESULT_ROWS = 100000
class SQLToolError(Exception):
    """Raised when the SQL tool fails to initialize, validate, or execute a query."""
class SQLTool:
    """
    SQL execution tool supporting BigQuery and MotherDuck backends.

    Converts natural-language requests to SQL via simple templates,
    validates queries against a deny-list, executes them on the configured
    backend, and returns results as pandas DataFrames.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """
        Initialize the tool and connect to the configured backend.

        Args:
            cfg: Application configuration; must provide ``sql_backend``
                ("bigquery" or "motherduck") plus backend-specific settings.
            tracer: Event tracer used to record query attempts and outcomes.

        Raises:
            SQLToolError: If the backend name is unknown or initialization fails.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend
        self.client = None
        logger.info(f"Initializing SQLTool with backend: {self.backend}")
        try:
            if self.backend == "bigquery":
                self._init_bigquery()
            elif self.backend == "motherduck":
                self._init_motherduck()
            else:
                raise SQLToolError(f"Unknown SQL backend: {self.backend}")
            logger.info(f"SQLTool initialized successfully with {self.backend}")
        except SQLToolError:
            # Already a precise domain error — don't double-wrap its message.
            logger.error("Failed to initialize SQLTool", exc_info=True)
            raise
        except Exception as e:
            logger.error(f"Failed to initialize SQLTool: {e}")
            raise SQLToolError(f"SQL backend initialization failed: {e}") from e

    def _init_bigquery(self):
        """
        Initialize the BigQuery client with service-account credentials.

        GCP_SERVICE_ACCOUNT_JSON may hold either the service-account JSON
        itself or a filesystem path to a JSON key file.

        Raises:
            SQLToolError: On missing dependencies or missing/invalid credentials.
        """
        # Narrow try: only the imports can raise ImportError here.
        try:
            from google.cloud import bigquery
            from google.oauth2 import service_account
        except ImportError as e:
            raise SQLToolError(
                "BigQuery dependencies not installed. "
                "Install with: pip install google-cloud-bigquery"
            ) from e

        key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
        if not key_json:
            raise SQLToolError(
                "Missing GCP_SERVICE_ACCOUNT_JSON environment variable. "
                "Please configure BigQuery credentials."
            )

        # Parse credentials: inline JSON if the value looks like an object,
        # otherwise treat it as a path to a key file.
        try:
            if key_json.strip().startswith("{"):
                info = json.loads(key_json)
            else:
                with open(key_json, 'r') as f:
                    info = json.load(f)
        except json.JSONDecodeError as e:
            raise SQLToolError(f"Invalid JSON in GCP_SERVICE_ACCOUNT_JSON: {e}") from e
        except FileNotFoundError as e:
            raise SQLToolError(f"GCP service account file not found: {key_json}") from e

        # Fail early with a precise message rather than a cryptic auth error later.
        required_fields = ["type", "project_id", "private_key", "client_email"]
        missing = [f for f in required_fields if f not in info]
        if missing:
            raise SQLToolError(
                f"GCP service account JSON missing required fields: {missing}"
            )

        creds = service_account.Credentials.from_service_account_info(info)
        project = self.cfg.gcp_project or info.get("project_id")
        if not project:
            raise SQLToolError("GCP project ID not specified in config or credentials")
        self.client = bigquery.Client(credentials=creds, project=project)
        logger.info(f"BigQuery client initialized for project: {project}")

    def _init_motherduck(self):
        """
        Initialize a MotherDuck (DuckDB) connection.

        Requires MOTHERDUCK_TOKEN (from config or environment). Connects to
        the account workspace or to a named database, creating the database
        on demand when ALLOW_CREATE_DB is enabled (default: true).

        Raises:
            SQLToolError: On missing dependency, missing token, or failed connection.
        """
        # Narrow try: only the import can raise ImportError here.
        try:
            import duckdb
        except ImportError as e:
            raise SQLToolError(
                "DuckDB not installed. Install with: pip install duckdb"
            ) from e

        version = duckdb.__version__
        logger.info(f"DuckDB version: {version}")
        # Warn (don't fail) when off the recommended version line.
        if not version.startswith("1.3"):
            logger.warning(
                f"DuckDB {version} detected. Recommended: 1.3.x for MotherDuck compatibility. "
                "Some features may not work as expected."
            )

        token = (self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN") or "").strip()
        if not token:
            raise SQLToolError(
                "Missing MOTHERDUCK_TOKEN. "
                "Get your token from: https://motherduck.com/docs/key-tasks/authenticating-to-motherduck"
            )
        db_name = (self.cfg.motherduck_db or "workspace").strip()
        allow_create = os.getenv("ALLOW_CREATE_DB", "true").lower() == "true"

        if db_name in RESERVED_MD_WORKSPACE_NAMES:
            # Workspace mode - no specific database context
            logger.info("Connecting to MotherDuck workspace")
            self.client = duckdb.connect(f"md:?motherduck_token={token}")
        else:
            # Try connecting to the specific database first.
            try:
                logger.info(f"Connecting to MotherDuck database: {db_name}")
                self.client = duckdb.connect(f"md:{db_name}?motherduck_token={token}")
            except Exception as db_err:
                # Redact the token before logging: driver errors may echo the
                # full connection URI, which would leak the secret into logs.
                safe_err = str(db_err).replace(token, "***")
                logger.warning(f"Direct connection to '{db_name}' failed: {safe_err}")
                # Fallback: connect to the workspace and create/select the DB.
                self.client = duckdb.connect(f"md:?motherduck_token={token}")
                self._ensure_db_context(db_name, allow_create)

        # Smoke-test the connection before declaring success.
        try:
            self.client.execute("SELECT 1").fetchone()
            logger.info("MotherDuck connection test successful")
        except Exception as e:
            raise SQLToolError(f"MotherDuck connection test failed: {e}") from e

    def _ensure_db_context(self, db_name: str, allow_create: bool):
        """
        Ensure the MotherDuck database context is set.

        Tries ``USE`` first; if the database is missing, creates it when
        *allow_create* is True, otherwise raises.

        Args:
            db_name: Target database name (sanitized before use in SQL).
            allow_create: Whether a missing database may be created.

        Raises:
            SQLToolError: If the database is missing and cannot/may not be created.
        """
        if db_name in RESERVED_MD_WORKSPACE_NAMES:
            return
        safe_name = self._quote_ident(db_name)

        # Try to USE the database first
        try:
            self.client.execute(f"USE {safe_name};")
            logger.info(f"Using existing database: {db_name}")
            return
        except Exception as use_err:
            logger.info(f"Database '{db_name}' not found: {use_err}")

        if not allow_create:
            raise SQLToolError(
                f"Database '{db_name}' does not exist and ALLOW_CREATE_DB is disabled. "
                f"Either create the database manually or set ALLOW_CREATE_DB=true"
            )

        # Attempt to create and use the database
        try:
            logger.info(f"Creating database: {db_name}")
            self.client.execute(f"CREATE DATABASE IF NOT EXISTS {safe_name};")
            self.client.execute(f"USE {safe_name};")
            logger.info(f"Database '{db_name}' created and selected")
        except Exception as create_err:
            raise SQLToolError(
                f"Failed to create database '{db_name}': {create_err}"
            ) from create_err

    @staticmethod
    def _quote_ident(name: str) -> str:
        """
        Sanitize a SQL identifier (despite the name, this rewrites rather
        than quotes): non-alphanumeric characters become underscores, and a
        leading digit gets an underscore prefix. Empty input yields "unnamed".
        """
        if not name:
            return "unnamed"
        # Remove dangerous characters
        safe = re.sub(r"[^a-zA-Z0-9_]", "_", name)
        # Ensure it doesn't start with a number
        if safe[0].isdigit():
            safe = "_" + safe
        return safe

    def _validate_sql(self, sql: str) -> tuple[bool, str]:
        """
        Validate a SQL query for basic safety.

        Args:
            sql: SQL text to validate.

        Returns:
            Tuple ``(is_valid, error_message)``; the message is "" when valid.
        """
        if not sql or not sql.strip():
            return False, "Empty SQL query"
        if len(sql) > MAX_QUERY_LENGTH:
            return False, f"Query too long (max {MAX_QUERY_LENGTH} characters)"

        # Block multiple statements: allow at most one *trailing* semicolon.
        # (A count-based check would let "SELECT 1; INSERT ..." through,
        # since a single separating semicolon still counts as one.)
        body = sql.strip()
        if body.endswith(';'):
            body = body[:-1]
        if ';' in body:
            return False, "Multiple SQL statements not allowed"

        # Deny-list of destructive/exec operations (case-insensitive).
        sql_lower = sql.lower()
        dangerous_patterns = [
            (r'\bdrop\s+table\b', "DROP TABLE"),
            (r'\bdrop\s+database\b', "DROP DATABASE"),
            (r'\bdelete\s+from\b', "DELETE FROM"),
            (r'\btruncate\b', "TRUNCATE"),
            (r'\bexec\s*\(', "EXEC"),
            (r'\bexecute\s*\(', "EXECUTE"),
        ]
        for pattern, name in dangerous_patterns:
            if re.search(pattern, sql_lower):
                logger.warning(f"Blocked query with {name} pattern")
                return False, f"Query contains blocked operation: {name}"
        return True, ""

    def _nl_to_sql(self, message: str) -> str:
        """
        Convert natural language to a SQL query.

        Simple heuristic templates - replace with a proper NL2SQL model for
        production. Messages that already look like SQL (SELECT or a WITH
        CTE) are passed through unchanged for downstream validation.
        """
        m = message.lower()
        # Pass through raw SQL, including CTE queries (WITH ... SELECT ...).
        if re.match(r'^\s*(select|with)\b', m):
            return message.strip()

        # Template-based generation (customize for your schema)
        if "avg" in m or "average" in m:
            if "by month" in m or "monthly" in m:
                return """
                SELECT
                    DATE_TRUNC('month', date_col) AS month,
                    AVG(metric_col) AS avg_metric
                FROM analytics.fact_table
                GROUP BY 1
                ORDER BY 1 DESC
                LIMIT 100;
                """
        if "top" in m:
            # Extract the requested row count if present, default to 10.
            match = re.search(r'top\s+(\d+)', m)
            limit = match.group(1) if match else "10"
            return f"""
            SELECT *
            FROM analytics.fact_table
            ORDER BY metric_col DESC
            LIMIT {limit};
            """
        if "count" in m:
            return """
            SELECT
                category_col,
                COUNT(*) AS count
            FROM analytics.fact_table
            GROUP BY 1
            ORDER BY 2 DESC
            LIMIT 100;
            """
        # Default fallback
        return """
        SELECT *
        FROM analytics.fact_table
        LIMIT 100;
        """

    def run(self, message: str) -> pd.DataFrame:
        """
        Execute a SQL query from natural language or a SQL statement.

        Args:
            message: Natural language query or SQL statement.

        Returns:
            DataFrame with query results, truncated to MAX_RESULT_ROWS.

        Raises:
            SQLToolError: If validation or query execution fails.
        """
        try:
            # Convert to SQL
            sql = self._nl_to_sql(message)
            logger.info(f"Generated SQL query (first 200 chars): {sql[:200]}")

            # Validate SQL
            is_valid, error_msg = self._validate_sql(sql)
            if not is_valid:
                raise SQLToolError(f"Invalid SQL query: {error_msg}")

            # Record the attempt (truncated to keep trace payloads small).
            self.tracer.trace_event("sql_query", {
                "sql": sql[:1000],
                "backend": self.backend,
                "message": message[:500]
            })

            # Execute based on backend
            if self.backend == "bigquery":
                result = self._execute_bigquery(sql)
            else:  # motherduck
                result = self._execute_duckdb(sql)

            if not isinstance(result, pd.DataFrame):
                raise SQLToolError("Query did not return a DataFrame")

            # Cap result size to protect downstream consumers.
            if len(result) > MAX_RESULT_ROWS:
                logger.warning(f"Result truncated from {len(result)} to {MAX_RESULT_ROWS} rows")
                result = result.head(MAX_RESULT_ROWS)

            logger.info(f"Query successful: {len(result)} rows, {len(result.columns)} columns")
            self.tracer.trace_event("sql_success", {
                "rows": len(result),
                "columns": len(result.columns)
            })
            return result
        except SQLToolError:
            raise
        except Exception as e:
            error_msg = f"SQL execution failed: {str(e)}"
            # logger.exception records the traceback for diagnosis.
            logger.exception(error_msg)
            self.tracer.trace_event("sql_error", {"error": error_msg})
            raise SQLToolError(error_msg) from e

    def _execute_bigquery(self, sql: str) -> pd.DataFrame:
        """Execute a query on BigQuery and return the result as a DataFrame."""
        try:
            query_job = self.client.query(sql)
            df = query_job.to_dataframe()
            return df
        except Exception as e:
            raise SQLToolError(f"BigQuery execution error: {str(e)}") from e

    def _execute_duckdb(self, sql: str) -> pd.DataFrame:
        """Execute a query on DuckDB/MotherDuck and return the result as a DataFrame."""
        try:
            result = self.client.execute(sql)
            df = result.fetch_df()
            return df
        except Exception as e:
            raise SQLToolError(f"DuckDB execution error: {str(e)}") from e

    def test_connection(self) -> bool:
        """Run a trivial query; return True when the backend answers correctly."""
        try:
            test_query = "SELECT 1 AS test"
            result = self.run(test_query)
            return len(result) == 1 and result.iloc[0, 0] == 1
        except Exception as e:
            logger.error(f"Connection test failed: {e}")
            return False