AshenH committed on
Commit
e002acf
·
verified ·
1 Parent(s): e3b4d13

Update tools/sql_tool.py

Browse files
Files changed (1) hide show
  1. tools/sql_tool.py +55 -28
tools/sql_tool.py CHANGED
@@ -5,59 +5,86 @@ from typing import Optional
5
  from utils.config import AppConfig
6
  from utils.tracing import Tracer
7
 
 
8
  class SQLTool:
9
  def __init__(self, cfg: AppConfig, tracer: Tracer):
10
  self.cfg = cfg
11
  self.tracer = tracer
12
  self.backend = cfg.sql_backend # "bigquery" or "motherduck"
 
13
  if self.backend == "bigquery":
14
  from google.cloud import bigquery
15
  from google.oauth2 import service_account
 
16
  key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
17
  if not key_json:
18
  raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
19
- creds = service_account.Credentials.from_service_account_info(
20
- eval(key_json) if key_json.strip().startswith("{") else {}
21
- )
 
 
 
 
 
 
22
  self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
23
- elif self.backend == "motherduck":
24
- import duckdb
25
- token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
26
- db_name = self.cfg.motherduck_db or "default"
27
-
28
- # Start a plain DuckDB connection
29
- self.client = duckdb.connect()
30
-
31
- # Ensure the MotherDuck extension is available and loaded
32
- # (DuckDB will download it automatically in this environment)
33
- self.client.execute("INSTALL motherduck;")
34
- self.client.execute("LOAD motherduck;")
35
-
36
- # Provide token and attach the remote MotherDuck database as 'md'
37
- if not token:
38
- raise RuntimeError("Missing MOTHERDUCK_TOKEN")
39
- self.client.execute(f"SET motherduck_token='{token}';")
40
- self.client.execute(f"ATTACH 'md:/{db_name}' AS md;")
41
- self.client.execute("USE md;") # subsequent queries run against 'md' by default
 
42
  else:
43
  raise RuntimeError("Unknown SQL backend")
44
 
45
  def _nl_to_sql(self, message: str) -> str:
46
- # Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
47
- # Expect users to include table names. Example: "avg revenue by month from dataset.sales"
 
 
 
48
  m = message.lower()
 
 
49
  if "avg" in m and " by " in m:
50
- return "-- Example template; edit me\nSELECT DATE_TRUNC(month, date_col) AS month, AVG(metric) AS avg_metric FROM dataset.table GROUP BY 1 ORDER BY 1;"
51
- # fallback: pass-through if user typed SQL explicitly
 
 
 
 
 
 
 
52
  if re.match(r"^\s*select ", m):
53
  return message
54
- return "SELECT * FROM dataset.table LIMIT 100;"
 
 
55
 
56
  def run(self, message: str) -> pd.DataFrame:
57
  sql = self._nl_to_sql(message)
58
  self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
 
59
  if self.backend == "bigquery":
60
  df = self.client.query(sql).to_dataframe()
61
  else:
 
62
  df = self.client.execute(sql).fetch_df()
63
- return df
 
 
5
  from utils.config import AppConfig
6
  from utils.tracing import Tracer
7
 
8
+
9
class SQLTool:
    """Run SQL (or a crude natural-language request) on BigQuery or MotherDuck.

    The backend is chosen by ``cfg.sql_backend``. The matching client is
    created in ``__init__`` and stored on ``self.client``.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Create the client for the configured backend.

        Args:
            cfg: Application config. ``sql_backend`` is always read; depending
                on the backend, ``gcp_project`` or ``motherduck_token`` /
                ``motherduck_db`` are read as well.
            tracer: Receives a ``"sql_query"`` event for every ``run`` call.

        Raises:
            RuntimeError: If a required secret is missing, or the backend is
                neither ``"bigquery"`` nor ``"motherduck"``.
            ValueError: If ``GCP_SERVICE_ACCOUNT_JSON`` is not a JSON object.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend  # "bigquery" or "motherduck"

        if self.backend == "bigquery":
            self.client = self._connect_bigquery()
        elif self.backend == "motherduck":
            self.client = self._connect_motherduck()
        else:
            raise RuntimeError("Unknown SQL backend")

    @staticmethod
    def _sql_literal(value) -> str:
        """Return ``value`` as a single-quoted SQL string literal.

        Embedded single quotes are doubled so the value cannot terminate the
        literal early.
        """
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _parse_service_account(key_json: str) -> dict:
        """Parse the service-account secret into a dict.

        Raises:
            ValueError: If the secret is not valid JSON or is not a JSON
                object. Previously a non-JSON secret was silently replaced by
                an empty dict and failed later with a less helpful error.
        """
        import json

        try:
            info = json.loads(key_json)
        except json.JSONDecodeError:
            # ``from None``: JSONDecodeError keeps the raw document on its
            # ``doc`` attribute; do not chain the secret into the new error.
            raise ValueError(
                "GCP_SERVICE_ACCOUNT_JSON is not valid JSON"
            ) from None
        if not isinstance(info, dict):
            raise ValueError("GCP_SERVICE_ACCOUNT_JSON must be a JSON object")
        return info

    def _connect_bigquery(self):
        """Build a BigQuery client from the GCP_SERVICE_ACCOUNT_JSON secret."""
        from google.cloud import bigquery
        from google.oauth2 import service_account

        key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
        if not key_json:
            raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")

        # Accept the full JSON string from the Space Secret.
        info = self._parse_service_account(key_json)
        creds = service_account.Credentials.from_service_account_info(info)
        return bigquery.Client(credentials=creds, project=self.cfg.gcp_project)

    def _connect_motherduck(self):
        """Open a DuckDB connection with the MotherDuck database attached as ``md``."""
        import duckdb

        token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
        db_name = self.cfg.motherduck_db or "default"
        if not token:
            raise RuntimeError("Missing MOTHERDUCK_TOKEN")

        # Plain DuckDB connection; MotherDuck is attached through the extension.
        conn = duckdb.connect()
        try:
            # Ensure the MotherDuck extension is available and loaded.
            conn.execute("INSTALL motherduck;")
            conn.execute("LOAD motherduck;")

            # Token and database name are spliced into SQL text, so quote them.
            conn.execute(f"SET motherduck_token={self._sql_literal(token)};")
            conn.execute(f"ATTACH {self._sql_literal('md:/' + db_name)} AS md;")
            conn.execute("USE md;")  # subsequent queries run against 'md' by default
        except Exception:
            # Do not leak the connection when setup fails half-way.
            conn.close()
            raise
        return conn

    def _nl_to_sql(self, message: str) -> str:
        """Translate a user message into a SQL string.

        Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
        Users are expected to include table names, e.g.
        ``"avg metric by month from analytics.events"``.

        Resolution order:
            1. A message that already starts with ``SELECT`` (any case,
               followed by any whitespace) is returned unchanged. This is
               checked first so that explicit SQL containing ``AVG`` and
               ``GROUP BY`` is not overwritten by the template below.
            2. A message containing ``"avg"`` and ``" by "`` yields a
               placeholder template (edit to your tables/columns).
            3. Anything else yields a placeholder ``SELECT * ... LIMIT 100``.
        """
        # Pass-through if the user typed SQL explicitly.
        if re.match(r"\s*select\s", message, re.IGNORECASE):
            return message

        lowered = message.lower()

        # Very basic template example (edit to your tables/columns).
        if "avg" in lowered and " by " in lowered:
            return (
                "-- Example template; edit me\n"
                "SELECT DATE_TRUNC('month', date_col) AS month, "
                "AVG(metric) AS avg_metric "
                "FROM analytics.table "
                "GROUP BY 1 ORDER BY 1;"
            )

        # Fallback.
        return "SELECT * FROM analytics.table LIMIT 100;"

    def run(self, message: str) -> pd.DataFrame:
        """Resolve ``message`` to SQL, trace it, execute it, and return the rows.

        Args:
            message: Raw SQL starting with ``SELECT`` or a natural-language
                request handled by ``_nl_to_sql``.

        Returns:
            The query result as a pandas DataFrame.
        """
        sql = self._nl_to_sql(message)
        self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})

        if self.backend == "bigquery":
            return self.client.query(sql).to_dataframe()

        # DuckDB (MotherDuck): fetch_df returns a pandas DataFrame.
        return self.client.execute(sql).fetch_df()