import os
from io import StringIO

import gradio as gr
import numpy as np
import pandas as pd
import statsmodels.api as sm

from langchain_groq import ChatGroq
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage

# SQL tools are expected to come from sql_tools.py; if that module is missing,
# fall back to placeholder tools so the app still starts.
try:
    from sql_tools import run_duckdb_query, get_table_schema
except ImportError:
    print("WARNING: Could not import from 'sql_tools.py'.")
    print("Using placeholder functions. Please create 'sql_tools.py'.")

    @tool
    def run_duckdb_query(query: str) -> str:
        """
        [PLACEHOLDER] Runs a read-only SQL query.
        Please create sql_tools.py to implement this.
        """
        if "schema" in query.lower() or "describe" in query.lower():
            return "report_date DATE, portfolio_id VARCHAR, sector VARCHAR, market_value_usd DOUBLE"
        return "Error: 'sql_tools.py' not found. This is a placeholder."

    @tool
    def get_table_schema(table_name: str = "positions") -> str:
        """
        [PLACEHOLDER] Returns the schema for the 'positions' table.
        Please create sql_tools.py to implement this.
        """
        return "report_date DATE, portfolio_id VARCHAR, sector VARCHAR, market_value_usd DOUBLE"

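# For reference, a minimal sql_tools.py might look roughly like the sketch
# below. The database path 'portfolio.duckdb' is an assumption; point it at
# wherever your DuckDB file with the 'positions' table actually lives.
#
#     import duckdb
#     from langchain_core.tools import tool
#
#     @tool
#     def run_duckdb_query(query: str) -> str:
#         """Runs a read-only SQL query against the portfolio database."""
#         con = duckdb.connect("portfolio.duckdb", read_only=True)
#         try:
#             return con.execute(query).fetchdf().to_string()
#         finally:
#             con.close()
#
#     @tool
#     def get_table_schema(table_name: str = "positions") -> str:
#         """Returns the column names and types of the given table."""
#         con = duckdb.connect("portfolio.duckdb", read_only=True)
#         try:
#             return con.execute(f"DESCRIBE {table_name}").fetchdf().to_string()
#         finally:
#             con.close()
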
@tool
def calculate_summary_statistics_from_data(data_string: str, column: str) -> str:
    """
    Calculates summary statistics (mean, median, std, min, max, count) for a
    specific 'column' from a 'data_string'.
    'data_string' should be the string output from the `run_duckdb_query` tool.
    """
    try:
        # Parse the whitespace-formatted query output.
        data_df = pd.read_csv(StringIO(data_string.strip()), sep=r"\s+", header=0)

        # Retry with the first column as the index (the query output may include a row index).
        if column not in data_df.columns:
            data_df = pd.read_csv(StringIO(data_string.strip()), sep=r"\s+", header=0, index_col=0)
        if column not in data_df.columns:
            return f"Error: Column '{column}' not found in data."

        stats = {
            "column": column,
            "mean": data_df[column].mean(),
            "median": data_df[column].median(),
            "std_dev": data_df[column].std(),
            "min": data_df[column].min(),
            "max": data_df[column].max(),
            "count": data_df[column].count()
        }
        return str(stats)
    except Exception as e:
        return f"Error in calculate_summary_statistics: {e}. Data input was: '{data_string[:200]}...'"

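# Both analysis tools (summary statistics above, ARIMA forecast below) parse the
# whitespace-formatted text returned by run_duckdb_query. A typical data_string
# looks something like this (rows are illustrative, mirroring the example in the
# system prompt further down):
#
#        report_date  total_value
#     0   2024-01-01     100000.0
#     1   2024-01-02     100500.0
#
# and the agent would then call, for example:
#     calculate_summary_statistics_from_data(data_string=..., column="total_value")
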
@tool
def perform_arima_forecast_from_data(data_string: str, time_column: str, value_column: str, forecast_periods: int) -> str:
    """
    Performs an ARIMA(1,1,1) forecast on a 'data_string'.
    'data_string': The string output from `run_duckdb_query`.
    'time_column': The name of the date/time column in the data.
    'value_column': The name of the numerical column to forecast.
    'forecast_periods': The number of periods (e.g., days) to forecast.

    The data MUST be ordered by the time_column before being passed to this tool.
    """
    try:
        # Parse the whitespace-formatted query output.
        data_df = pd.read_csv(StringIO(data_string.strip()), sep=r"\s+", header=0)

        # Retry with the first column as the index (the query output may include a row index).
        if time_column not in data_df.columns:
            data_df = pd.read_csv(StringIO(data_string.strip()), sep=r"\s+", header=0, index_col=0)
        if time_column not in data_df.columns:
            return f"Error: Time column '{time_column}' not found in data."

        if value_column not in data_df.columns:
            return f"Error: Value column '{value_column}' not found in data."

        if data_df.empty:
            return "Error: Query returned no data."

        # Build a daily time series and forward-fill any missing dates.
        data_df[time_column] = pd.to_datetime(data_df[time_column])
        data_df = data_df.set_index(time_column)
        data_df = data_df.asfreq('D')
        data_df[value_column] = data_df[value_column].ffill()

        model = sm.tsa.ARIMA(data_df[value_column], order=(1, 1, 1))
        results = model.fit()
        forecast = results.forecast(steps=forecast_periods)

        forecast_df = pd.DataFrame({
            'date': forecast.index.strftime('%Y-%m-%d'),
            'forecasted_value': forecast.values
        })

        return f"Forecast successful. Last historical value was {data_df[value_column].iloc[-1]:.2f}.\nForecast:\n{forecast_df.to_string()}"

    except Exception as e:
        return f"Error in perform_arima_forecast: {e}. Data input was: '{data_string[:200]}...'"

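# Note: the forecast tool assumes roughly daily data. It reindexes the series to
# a daily frequency and forward-fills gaps before fitting a fixed ARIMA(1,1,1);
# for very short or irregular series the fit may fail, in which case the error
# string is returned to the agent instead of a forecast.
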
if "GROQ_API_KEY" not in os.environ:
    print("GROQ_API_KEY not found in secrets!")

    def missing_key_error(message, history):
        return "Error: `GROQ_API_KEY` is not set in this Space's Secrets. Please add it to use the app."

    gr.ChatInterface(
        missing_key_error,
        title="Agentic Portfolio Analyst",
        description="Error: GROQ_API_KEY secret is missing."
    ).launch()

else:
    print("GROQ_API_KEY found. Initializing agent...")
    llm = ChatGroq(model_name="llama-3.3-70b-versatile")

    tools = [
        run_duckdb_query,
        get_table_schema,
        calculate_summary_statistics_from_data,
        perform_arima_forecast_from_data
    ]

    system_prompt = """
    You are an expert portfolio analyst. You have access to SQL tools and analysis tools.

    Your logic MUST follow these steps:
    1. Use `get_table_schema` to understand the data.
    2. Use `run_duckdb_query` to fetch the raw data you need.
    3. If analysis (statistics or forecasting) is needed, take the string output
       from `run_duckdb_query` and pass it *directly* to either
       `calculate_summary_statistics_from_data` or `perform_arima_forecast_from_data`.

    Example for forecasting:
    1. Call `run_duckdb_query("SELECT report_date, SUM(market_value_usd) AS total_value FROM positions WHERE sector = 'Tech' GROUP BY report_date ORDER BY report_date")`.
    2. Get the result string: " report_date total_value \n 2024-01-01 100000.0 \n 2024-01-02 100500.0 \n ..."
    3. Call `perform_arima_forecast_from_data(data_string=" report_date total_value \n 2024-01-01 100000.0 \n ...", time_column="report_date", value_column="total_value", forecast_periods=30)`.

    Answer the user's request based on the final tool output.
    """

    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}"),
        ]
    )

    agent = create_tool_calling_agent(llm, tools, prompt)

    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True
    )

    def run_agent(message, history):
        # Convert Gradio's (user, assistant) tuples into LangChain-style messages.
        chat_history = []
        for human_msg, ai_msg in history:
            chat_history.append(("human", human_msg))
            chat_history.append(("ai", ai_msg))

        try:
            response = agent_executor.invoke({
                "input": message,
                "chat_history": chat_history
            })
            return response["output"]
        except Exception as e:
            return f"An error occurred: {e}"

    gr.ChatInterface(
        run_agent,
        title="Agentic Portfolio Analyst",
        description="Ask me questions about your portfolio. (This app uses imported SQL tools.)",
        examples=[
            "What is the schema of the positions table?",
            "What's the total market value by sector on the last available date?",
            "Give me summary statistics for the 'Tech' sector's market value from portfolio P-123. Use the 'market_value_usd' column for stats.",
            "What is the 30-day forecast for the total market value of portfolio P-123? Use 'total_value' for the forecast value column."
        ]
    ).launch()
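
# To run this file locally (assuming the usual dependencies: gradio, pandas,
# numpy, statsmodels, langchain, langchain-groq, plus duckdb if sql_tools.py
# uses it), something like the following should start the UI:
#
#     export GROQ_API_KEY=...   # your Groq API key
#     python app.py             # or whatever this file is named
#
# On Hugging Face Spaces, set GROQ_API_KEY as a Space secret instead.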