AshenH commited on
Commit
f4dc602
·
verified ·
1 Parent(s): 4b20efc

Update tools/sql_tool.py

Browse files
Files changed (1) hide show
  1. tools/sql_tool.py +48 -48
tools/sql_tool.py CHANGED
@@ -1,49 +1,49 @@
1
- import os
2
- import re
3
- import pandas as pd
4
- from typing import Optional
5
- from ..utils.config import AppConfig
6
- from ..utils.tracing import Tracer
7
-
8
- class SQLTool:
9
- def __init__(self, cfg: AppConfig, tracer: Tracer):
10
- self.cfg = cfg
11
- self.tracer = tracer
12
- self.backend = cfg.sql_backend # "bigquery" or "motherduck"
13
- if self.backend == "bigquery":
14
- from google.cloud import bigquery
15
- from google.oauth2 import service_account
16
- key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
17
- if not key_json:
18
- raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
19
- creds = service_account.Credentials.from_service_account_info(
20
- eval(key_json) if key_json.strip().startswith("{") else {}
21
- )
22
- self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
23
- elif self.backend == "motherduck":
24
- import duckdb
25
- token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
26
- db_name = self.cfg.motherduck_db or "default"
27
- self.client = duckdb.connect(f"md:/{db_name}?motherduck_token={token}")
28
- else:
29
- raise RuntimeError("Unknown SQL backend")
30
-
31
- def _nl_to_sql(self, message: str) -> str:
32
- # Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
33
- # Expect users to include table names. Example: "avg revenue by month from dataset.sales"
34
- m = message.lower()
35
- if "avg" in m and " by " in m:
36
- return "-- Example template; edit me\nSELECT DATE_TRUNC(month, date_col) AS month, AVG(metric) AS avg_metric FROM dataset.table GROUP BY 1 ORDER BY 1;"
37
- # fallback: pass-through if user typed SQL explicitly
38
- if re.match(r"^\s*select ", m):
39
- return message
40
- return "SELECT * FROM dataset.table LIMIT 100;"
41
-
42
- def run(self, message: str) -> pd.DataFrame:
43
- sql = self._nl_to_sql(message)
44
- self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
45
- if self.backend == "bigquery":
46
- df = self.client.query(sql).to_dataframe()
47
- else:
48
- df = self.client.execute(sql).fetch_df()
49
  return df
 
1
+ import os
2
+ import re
3
+ import pandas as pd
4
+ from typing import Optional
5
+ from utils.config import AppConfig
6
+ from utils.tracing import Tracer
7
+
8
+ class SQLTool:
9
+ def __init__(self, cfg: AppConfig, tracer: Tracer):
10
+ self.cfg = cfg
11
+ self.tracer = tracer
12
+ self.backend = cfg.sql_backend # "bigquery" or "motherduck"
13
+ if self.backend == "bigquery":
14
+ from google.cloud import bigquery
15
+ from google.oauth2 import service_account
16
+ key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
17
+ if not key_json:
18
+ raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
19
+ creds = service_account.Credentials.from_service_account_info(
20
+ eval(key_json) if key_json.strip().startswith("{") else {}
21
+ )
22
+ self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
23
+ elif self.backend == "motherduck":
24
+ import duckdb
25
+ token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
26
+ db_name = self.cfg.motherduck_db or "default"
27
+ self.client = duckdb.connect(f"md:/{db_name}?motherduck_token={token}")
28
+ else:
29
+ raise RuntimeError("Unknown SQL backend")
30
+
31
+ def _nl_to_sql(self, message: str) -> str:
32
+ # Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
33
+ # Expect users to include table names. Example: "avg revenue by month from dataset.sales"
34
+ m = message.lower()
35
+ if "avg" in m and " by " in m:
36
+ return "-- Example template; edit me\nSELECT DATE_TRUNC(month, date_col) AS month, AVG(metric) AS avg_metric FROM dataset.table GROUP BY 1 ORDER BY 1;"
37
+ # fallback: pass-through if user typed SQL explicitly
38
+ if re.match(r"^\s*select ", m):
39
+ return message
40
+ return "SELECT * FROM dataset.table LIMIT 100;"
41
+
42
+ def run(self, message: str) -> pd.DataFrame:
43
+ sql = self._nl_to_sql(message)
44
+ self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
45
+ if self.backend == "bigquery":
46
+ df = self.client.query(sql).to_dataframe()
47
+ else:
48
+ df = self.client.execute(sql).fetch_df()
49
  return df