File size: 5,467 Bytes
b07564d
f4dc602
 
b07564d
f4dc602
 
 
 
e002acf
85b8a4e
 
 
f4dc602
 
 
 
 
e002acf
85b8a4e
f4dc602
 
 
2336094
f4dc602
 
 
e002acf
b07564d
e002acf
f4dc602
e002acf
85b8a4e
e002acf
b07564d
2336094
 
9d6bac9
 
b07564d
 
9d6bac9
e002acf
2336094
 
 
e002acf
 
 
85b8a4e
 
 
2336094
85b8a4e
 
 
 
 
 
 
 
2336094
 
f4dc602
2336094
f4dc602
2336094
 
 
 
85b8a4e
2336094
85b8a4e
 
 
 
 
2336094
 
85b8a4e
2336094
 
 
 
 
 
85b8a4e
 
 
 
 
 
 
 
 
 
 
2336094
 
 
e002acf
85b8a4e
e002acf
85b8a4e
2336094
 
 
 
f4dc602
e002acf
2336094
f4dc602
e002acf
2336094
 
 
 
 
b07564d
e002acf
 
2336094
85b8a4e
f4dc602
e002acf
2336094
e002acf
f4dc602
2336094
f4dc602
 
b07564d
 
 
 
e002acf
f4dc602
54614e9
f4dc602
54614e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# space/tools/sql_tool.py
import os
import re
import json
import pandas as pd
from utils.config import AppConfig
from utils.tracing import Tracer


RESERVED_MD_WORKSPACE_NAMES = {"", "workspace", "default"}  # treat these as workspace/no-DB context

# Detects a message that already *is* a SELECT statement. It is matched against
# the lower-cased message, so no IGNORECASE flag is needed. `\s` (not a literal
# space) also accepts "select\n..." and "select\t...".
_SELECT_SQL_RE = re.compile(r"^\s*select\s")


class SQLTool:
    """Run SQL against BigQuery or MotherDuck and return results as DataFrames.

    The backend is chosen by ``cfg.sql_backend`` ("bigquery" or "motherduck").
    Incoming messages are turned into SQL by a toy heuristic (``_nl_to_sql``);
    messages that already are SELECT statements are executed unchanged.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Create the client for the configured backend.

        Args:
            cfg: application config; reads ``sql_backend`` plus backend-specific
                fields (``gcp_project`` / ``motherduck_token`` / ``motherduck_db``).
            tracer: used by ``run`` to emit a ``sql_query`` trace event.

        Raises:
            RuntimeError: missing/invalid credentials, incompatible DuckDB
                version, unusable MotherDuck database, or unknown backend.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend  # "bigquery" or "motherduck"

        if self.backend == "bigquery":
            self._init_bigquery()
        elif self.backend == "motherduck":
            self._init_motherduck()
        else:
            raise RuntimeError(f"Unknown SQL backend: {self.backend}")

    # ----- Backend setup -----
    def _init_bigquery(self) -> None:
        """Build ``self.client`` as a BigQuery client from GCP_SERVICE_ACCOUNT_JSON.

        The env var may hold the service-account JSON itself or a path to a
        key file.
        """
        from google.cloud import bigquery
        from google.oauth2 import service_account

        key_json = (os.getenv("GCP_SERVICE_ACCOUNT_JSON") or "").strip()
        if not key_json:
            raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")

        if key_json.startswith("{"):
            try:
                info = json.loads(key_json)
            except json.JSONDecodeError as err:
                raise RuntimeError("GCP_SERVICE_ACCOUNT_JSON is not valid JSON") from err
            creds = service_account.Credentials.from_service_account_info(info)
        elif os.path.isfile(key_json):
            # Not inline JSON: accept a path to a key file instead of failing
            # later with an opaque "missing fields" error on an empty dict.
            creds = service_account.Credentials.from_service_account_file(key_json)
        else:
            raise RuntimeError(
                "GCP_SERVICE_ACCOUNT_JSON must contain the service-account JSON "
                "or a path to a service-account key file"
            )
        self.client = bigquery.Client(credentials=creds, project=self.cfg.gcp_project)

    def _init_motherduck(self) -> None:
        """Build ``self.client`` as a DuckDB connection to MotherDuck."""
        import duckdb

        # MotherDuck extension compatibility: widely supported ABI is DuckDB 1.3.2
        if not duckdb.__version__.startswith("1.3.2"):
            raise RuntimeError(
                f"Incompatible DuckDB version {duckdb.__version__}. "
                "Pin duckdb==1.3.2 in requirements.txt and redeploy."
            )

        token = (self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN") or "").strip()
        db_name = (self.cfg.motherduck_db or "workspace").strip()
        allow_create = os.getenv("ALLOW_CREATE_DB", "true").lower() == "true"
        if not token:
            raise RuntimeError("Missing MOTHERDUCK_TOKEN")

        workspace_dsn = f"md:?motherduck_token={token}"
        if db_name in RESERVED_MD_WORKSPACE_NAMES:
            # Workspace mode: no USE/CREATE; callers fully-qualify tables if needed.
            self.client = duckdb.connect(workspace_dsn)
            return

        # Prefer a direct connection to the named DB.
        try:
            self.client = duckdb.connect(f"md:{db_name}?motherduck_token={token}")
        except Exception:
            # Fallback: connect to the workspace, then USE (or CREATE, if
            # permitted) the DB. Any error raised below is implicitly chained
            # to the original connect failure via __context__.
            self.client = duckdb.connect(workspace_dsn)
            self._ensure_db_context(db_name, allow_create)

    # ----- MotherDuck helpers -----
    def _ensure_db_context(self, db_name: str, allow_create: bool) -> None:
        """Switch the connection to ``db_name``, creating it if allowed.

        Tries ``USE`` first; if that fails and ``allow_create`` is true, runs
        ``CREATE DATABASE`` followed by ``USE``. Reserved workspace names are a
        no-op.

        Raises:
            RuntimeError: the DB cannot be used and creation is disallowed, or
                creation/use failed. Both underlying errors are reported.
        """
        if db_name in RESERVED_MD_WORKSPACE_NAMES:
            return

        ident = self._quote_ident(db_name)
        try:
            self.client.execute(f"USE {ident};")
        except Exception as err:
            use_err = err  # keep a reference; `err` is unbound after the block
        else:
            return

        if not allow_create:
            raise RuntimeError(
                f"Database '{db_name}' not found and ALLOW_CREATE_DB is false. "
                f"Original error: {use_err}"
            ) from use_err

        try:
            # CREATE DATABASE <name>; is supported on MotherDuck for valid names (not 'default')
            self.client.execute(f"CREATE DATABASE {ident};")
            self.client.execute(f"USE {ident};")
        except Exception as create_err:
            raise RuntimeError(
                f"Could not create or use database '{db_name}'. "
                f"Original errors: USE: {use_err}; CREATE: {create_err}"
            ) from create_err

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Sanitize ``name`` to ``[A-Za-z0-9_]`` and wrap it in double quotes.

        Characters outside that set become underscores (so "my-db" -> "my_db").
        The double quotes make digit-leading or keyword names (e.g. "2024",
        "order") valid identifiers; DuckDB identifiers stay case-insensitive
        when quoted. ``None`` is treated as an empty name.
        """
        safe = re.sub(r"[^a-zA-Z0-9_]", "_", name or "")
        return f'"{safe}"'

    # ----- NL → SQL heuristic (toy example; edit to your schema) -----
    def _nl_to_sql(self, message: str) -> str:
        """Map a user message to SQL.

        A message that already starts with SELECT is returned unchanged. This is
        checked first, so user SQL containing "avg ... by" is not replaced by the
        template. Otherwise an "avg ... by" request gets an example aggregate
        template, and anything else gets a LIMIT 100 fallback. ``None`` is
        treated as an empty message.
        """
        m = (message or "").lower()

        # If user typed SQL already, run it as-is
        if _SELECT_SQL_RE.match(m):
            return message

        # Example DuckDB/MotherDuck flavor of DATE_TRUNC
        if "avg" in m and " by " in m:
            return (
                "-- Example template; edit to your schema/columns\n"
                "SELECT DATE_TRUNC('month', date_col) AS month,\n"
                "       AVG(metric) AS avg_metric\n"
                "FROM analytics.table\n"
                "GROUP BY 1\n"
                "ORDER BY 1;"
            )

        # Fallback
        return "SELECT * FROM analytics.table LIMIT 100;"

    # ----- Execute -----
    def run(self, message: str) -> pd.DataFrame:
        """Translate ``message`` to SQL, execute it, and return a DataFrame."""
        sql = self._nl_to_sql(message)
        try:
            self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
        except Exception:
            # Tracing is best-effort: a tracer failure must not block the query.
            pass

        if self.backend == "bigquery":
            return self.client.query(sql).to_dataframe()
        return self.client.execute(sql).fetch_df()