AshenH committed on
Commit 9b3d9a0 · verified · 1 Parent(s): 30413d2

Update app.py

Files changed (1)
  1. app.py +206 -59
app.py CHANGED
@@ -1,70 +1,217 @@
- # app.py
- import os
- import pandas as pd
  import gradio as gr

- from tools.sql_tool import SQLTool
- from tools.ts_preprocess import build_timeseries

- DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")

- sql_tool = SQLTool(DUCKDB_PATH)
- RESOLVED_PATH = sql_tool.get_full_table_path() # e.g., my_db.main.masterdataset_v or main.masterdataset_v

- INTRO = f"""
- ### ALM LLM — Demo

- Connected to **DuckDB** at `{DUCKDB_PATH}`
- Using table **{RESOLVED_PATH}** (auto-resolved).
- """

- def run_nl(nl_query: str):
-     if not nl_query or not nl_query.strip():
-         return pd.DataFrame(), "", "Please enter a query.", pd.DataFrame(), pd.DataFrame()
      try:
-         df, sql, why = sql_tool.query_from_nl(nl_query)
      except Exception as e:
-         return pd.DataFrame(), "", f"Error: {e}", pd.DataFrame(), pd.DataFrame()
-     try:
-         cf, gap = build_timeseries(df)
-     except Exception:
-         cf, gap = pd.DataFrame(), pd.DataFrame()
-     return df, sql.strip(), why, cf, gap
-
- def run_sql(sql_text: str):
-     if not sql_text or not sql_text.strip():
-         return pd.DataFrame(), "Please paste a SQL statement.", pd.DataFrame(), pd.DataFrame()
      try:
-         df = sql_tool.run_sql(sql_text)
      except Exception as e:
-         return pd.DataFrame(), f"Error: {e}", pd.DataFrame(), pd.DataFrame()
-     try:
-         cf, gap = build_timeseries(df)
-     except Exception:
-         cf, gap = pd.DataFrame(), pd.DataFrame()
-     return df, "OK", cf, gap
-
- with gr.Blocks(title="ALM LLM") as demo:
-     gr.Markdown(INTRO)
-
-     with gr.Tab("Ask in Natural Language"):
-         nl = gr.Textbox(label="Ask a question", placeholder="e.g., show me the top 10 fds by portfolio value", lines=2)
-         btn = gr.Button("Run")
-         sql_out = gr.Textbox(label="Generated SQL", interactive=False)
-         why_out = gr.Textbox(label="Reasoning", interactive=False)
-         df_out = gr.Dataframe(label="Query Result", interactive=True)
-         cf_out = gr.Dataframe(label="Projected Cash-Flows (if applicable)", interactive=True)
-         gap_out = gr.Dataframe(label="Liquidity Gap (monthly)", interactive=True)
-         btn.click(fn=run_nl, inputs=[nl], outputs=[df_out, sql_out, why_out, cf_out, gap_out])
-
-     with gr.Tab("Run Raw SQL"):
-         sql_in = gr.Code(label="SQL", language="sql", value=f"SELECT * FROM {RESOLVED_PATH} LIMIT 20;")
-         btn2 = gr.Button("Execute")
-         df2 = gr.Dataframe(label="Result", interactive=True)
-         status = gr.Textbox(label="Status", interactive=False)
-         cf2 = gr.Dataframe(label="Projected Cash-Flows (if applicable)", interactive=True)
-         gap2 = gr.Dataframe(label="Liquidity Gap (monthly)", interactive=True)
-         btn2.click(fn=run_sql, inputs=[sql_in], outputs=[df2, status, cf2, gap2])
-
- if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))

  import gradio as gr
+ import pandas as pd
+ import numpy as np
+ import os
+ import statsmodels.api as sm
+ from io import StringIO

+ # --- LangChain Imports ---
+ from langchain_groq import ChatGroq
+ from langchain.agents import AgentExecutor, create_tool_calling_agent
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.tools import tool
+ from langchain_core.messages import SystemMessage

+ # --- ASSUMPTION ---
+ # Assuming you have a file named 'sql_tools.py' in the same directory
+ # with your pre-built and decorated @tool functions.
+ try:
+     from sql_tools import run_duckdb_query, get_table_schema
+ except ImportError:
+     print("WARNING: Could not import from 'sql_tools.py'.")
+     print("Using placeholder functions. Please create 'sql_tools.py'.")
+
+     # Create placeholder tools if the file is missing, so the app can start
+     @tool
+     def run_duckdb_query(query: str) -> str:
+         """
+         [PLACEHOLDER] Runs a read-only SQL query.
+         Please create sql_tools.py to implement this.
+         """
+         if "schema" in query.lower() or "describe" in query.lower():
+             return "report_date DATE, portfolio_id VARCHAR, sector VARCHAR, market_value_usd DOUBLE"
+         return "Error: 'sql_tools.py' not found. This is a placeholder."

+     @tool
+     def get_table_schema(table_name: str = "positions") -> str:
+         """
+         [PLACEHOLDER] Returns the schema for the 'positions' table.
+         Please create sql_tools.py to implement this.
+         """
+         return "report_date DATE, portfolio_id VARCHAR, sector VARCHAR, market_value_usd DOUBLE"


+ # --- Agent Tools ---
+ # These tools perform analysis on data *after* it has been fetched.

+ @tool
+ def calculate_summary_statistics_from_data(data_string: str, column: str) -> str:
+     """
+     Calculates summary statistics (mean, median, std, min, max) for a specific
+     'column' from a 'data_string'.
+     'data_string' should be the string output from the `run_duckdb_query` tool.
+     """
      try:
+         # Convert the string data back into a DataFrame
+         data_df = pd.read_csv(StringIO(data_string.strip()), delim_whitespace=True, header=0)
+
+         # HACK: The string output might have an extra index column, let's find the real columns
+         if column not in data_df.columns:
+             # Try reading again, assuming first column is an unnamed index
+             data_df = pd.read_csv(StringIO(data_string.strip()), delim_whitespace=True, header=0, index_col=0)
+             if column not in data_df.columns:
+                 return f"Error: Column '{column}' not found in data."
+
+         stats = {
+             "column": column,
+             "mean": data_df[column].mean(),
+             "median": data_df[column].median(),
+             "std_dev": data_df[column].std(),
+             "min": data_df[column].min(),
+             "max": data_df[column].max(),
+             "count": data_df[column].count()
+         }
+         return str(stats)
      except Exception as e:
+         return f"Error in calculate_summary_statistics: {e}. Data input was: '{data_string[:200]}...'"
+
+ @tool
+ def perform_arima_forecast_from_data(data_string: str, time_column: str, value_column: str, forecast_periods: int) -> str:
+     """
+     Performs an ARIMA(1,1,1) forecast on a 'data_string'.
+     'data_string': The string output from `run_duckdb_query`.
+     'time_column': The name of the date/time column in the data.
+     'value_column': The name of the numerical column to forecast.
+     'forecast_periods': The number of periods (e.g., days) to forecast.
+
+     The data MUST be ordered by the time_column before being passed to this tool.
+     """
      try:
+         # Convert the string data back into a DataFrame
+         data_df = pd.read_csv(StringIO(data_string.strip()), delim_whitespace=True, header=0)
+
+         # HACK: The string output might have an extra index column
+         if time_column not in data_df.columns:
+             data_df = pd.read_csv(StringIO(data_string.strip()), delim_whitespace=True, header=0, index_col=0)
+             if time_column not in data_df.columns:
+                 return f"Error: Time column '{time_column}' not found in data."
+
+         if value_column not in data_df.columns:
+             return f"Error: Value column '{value_column}' not found in data."
+
+         if data_df.empty:
+             return "Error: Query returned no data."
+
+         # Prepare data for statsmodels
+         data_df[time_column] = pd.to_datetime(data_df[time_column])
+         data_df = data_df.set_index(time_column)
+         data_df = data_df.asfreq('D') # Ensure daily frequency, fill gaps if any
+         data_df[value_column] = data_df[value_column].fillna(method='ffill')
+
+         model = sm.tsa.ARIMA(data_df[value_column], order=(1, 1, 1))
+         results = model.fit()
+         forecast = results.forecast(steps=forecast_periods)
+
+         forecast_df = pd.DataFrame({
+             'date': forecast.index.strftime('%Y-%m-%d'),
+             'forecasted_value': forecast.values
+         })
+
+         return f"Forecast successful. Last historical value was {data_df[value_column].iloc[-1]:.2f}.\nForecast:\n{forecast_df.to_string()}"
+
      except Exception as e:
+         return f"Error in perform_arima_forecast: {e}. Data input was: '{data_string[:200]}...'"
+
+ # --- Main Agent and UI Setup ---
+
+ # Check for the GROQ_API_KEY in Hugging Face Space Secrets
+ if "GROQ_API_KEY" not in os.environ:
+     print("GROQ_API_KEY not found in secrets!")
+     def missing_key_error(message, history):
+         return "Error: `GROQ_API_KEY` is not set in this Space's Secrets. Please add it to use the app."
+
+     gr.ChatInterface(
+         missing_key_error,
+         title="Agentic Portfolio Analyst",
+         description="Error: GROQ_API_KEY secret is missing."
+     ).launch()
+
+ else:
+     print("GROQ_API_KEY found. Initializing agent...")
+     llm = ChatGroq(model_name="llama-3.3-70b-versatile")
+
+     # 2. Collect all our tools (imported and local)
+     tools = [
+         run_duckdb_query,
+         get_table_schema,
+         calculate_summary_statistics_from_data,
+         perform_arima_forecast_from_data
+     ]
+
+     # 3. Create the Agent Prompt
+     system_prompt = """
+     You are an expert portfolio analyst. You have access to SQL tools and analysis tools.
+
+     Your logic MUST follow these steps:
+     1. Use `get_table_schema` to understand the data.
+     2. Use `run_duckdb_query` to fetch the raw data you need.
+     3. If analysis (statistics or forecasting) is needed, take the string output
+        from `run_duckdb_query` and pass it *directly* to either
+        `calculate_summary_statistics_from_data` or `perform_arima_forecast_from_data`.
+
+     Example for forecasting:
+     1. Call `run_duckdb_query("SELECT report_date, SUM(market_value_usd) AS total_value FROM positions WHERE sector = 'Tech' GROUP BY report_date ORDER BY report_date")`.
+     2. Get the result string: " report_date total_value \n 2024-01-01 100000.0 \n 2024-01-02 100500.0 \n ..."
+     3. Call `perform_arima_forecast_from_data(data_string=" report_date total_value \n 2024-01-01 100000.0 \n ...", time_column="report_date", value_column="total_value", forecast_periods=30)`.
+
+     Answer the user's request based on the final tool output.
+     """
+
+     prompt = ChatPromptTemplate.from_messages(
+         [
+             SystemMessage(content=system_prompt),
+             ("placeholder", "{chat_history}"),
+             ("human", "{input}"),
+             ("placeholder", "{agent_scratchpad}"),
+         ]
+     )
+
+     # 4. Create the Agent
+     agent = create_tool_calling_agent(llm, tools, prompt)
+
+     # 5. Create the Agent Executor
+     agent_executor = AgentExecutor(
+         agent=agent,
+         tools=tools,
+         verbose=True
+     )
+
+     # 6. Define the function for Gradio
+     def run_agent(message, history):
+         chat_history = []
+         for human_msg, ai_msg in history:
+             chat_history.append(("human", human_msg))
+             chat_history.append(("ai", ai_msg))
+
+         try:
+             response = agent_executor.invoke({
+                 "input": message,
+                 "chat_history": chat_history
+             })
+             return response["output"]
+         except Exception as e:
+             return f"An error occurred: {e}"
+
+     # 7. Launch the Gradio App
+     gr.ChatInterface(
+         run_agent,
+         title="Agentic Portfolio Analyst",
+         description="Ask me questions about your portfolio. (This app uses imported SQL tools).",
+         examples=[
+             "What is the schema of the positions table?",
+             "What's the total market value by sector on the last available date?",
+             "Give me summary statistics for the 'Tech' sector's market value from portfolio P-123. Use the 'market_value_usd' column for stats.",
+             "What is the 30-day forecast for the total market value of portfolio P-123? Use 'total_value' for the forecast value column."
+         ]
+     ).launch()
+
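Note: the new `app.py` imports `run_duckdb_query` and `get_table_schema` from `sql_tools.py`, which is not part of this commit; without that file the app falls back to the placeholder tools above. Below is a minimal sketch of what such a module could look like, assuming a local DuckDB file (the `portfolio.duckdb` path is hypothetical) containing a `positions` table with the columns echoed by the placeholder schema. The read-only guard and helper names are illustrative assumptions, not the author's actual implementation.

```python
# sql_tools.py -- hedged sketch only; the real module in this repo may differ.
import duckdb
import pandas as pd
from langchain_core.tools import tool

DB_PATH = "portfolio.duckdb"  # hypothetical path; point this at the Space's actual DuckDB file


def _connect():
    # Read-only connection so the agent cannot modify the database (assumed policy).
    return duckdb.connect(database=DB_PATH, read_only=True)


@tool
def run_duckdb_query(query: str) -> str:
    """Runs a read-only SQL query against the portfolio database and returns
    the result as a whitespace-formatted table string."""
    lowered = query.strip().lower()
    if not lowered.startswith(("select", "with", "describe", "show")):
        return "Error: only read-only queries (SELECT/WITH/DESCRIBE/SHOW) are allowed."
    try:
        con = _connect()
        df: pd.DataFrame = con.execute(query).fetchdf()
        con.close()
        if df.empty:
            return "Query returned no rows."
        # to_string() keeps pandas' index column; the analysis tools in app.py
        # already tolerate this via their index_col=0 fallback.
        return df.to_string()
    except Exception as e:
        return f"Error running query: {e}"


@tool
def get_table_schema(table_name: str = "positions") -> str:
    """Returns the column names and types for a table (default: 'positions')."""
    try:
        con = _connect()
        df = con.execute(f"DESCRIBE {table_name}").fetchdf()
        con.close()
        return ", ".join(
            f"{row.column_name} {row.column_type}" for row in df.itertuples()
        )
    except Exception as e:
        return f"Error reading schema for '{table_name}': {e}"
```

Returning `df.to_string()` matches the whitespace-delimited format that `calculate_summary_statistics_from_data` and `perform_arima_forecast_from_data` expect to parse back with `pd.read_csv`.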