File size: 8,951 Bytes
e81b80e 9b3d9a0 af53f4b 9b3d9a0 68c51bb 9b3d9a0 af53f4b 9b3d9a0 da25b2a 0ffc27e 9b3d9a0 af53f4b 9b3d9a0 e81b80e 9b3d9a0 e81b80e 9b3d9a0 da25b2a 9b3d9a0 da25b2a 9b3d9a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
import gradio as gr
import pandas as pd
import numpy as np
import os
import statsmodels.api as sm
from io import StringIO
# --- LangChain Imports ---
from langchain_groq import ChatGroq
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
# --- ASSUMPTION ---
# Assuming you have a file named 'sql_tools.py' in the same directory
# with your pre-built and decorated @tool functions.
try:
    from sql_tools import run_duckdb_query, get_table_schema
except ImportError as import_error:
    # Report the underlying cause: ImportError is also raised when
    # 'sql_tools.py' exists but one of *its* imports fails, and the old
    # message made that case look like a missing file.
    print(f"WARNING: Could not import from 'sql_tools.py': {import_error}")
    print("Using placeholder functions. Please create 'sql_tools.py'.")

    # Single source of truth for the fake schema both placeholders return.
    _PLACEHOLDER_SCHEMA = "report_date DATE, portfolio_id VARCHAR, sector VARCHAR, market_value_usd DOUBLE"

    # Create placeholder tools if the file is missing, so the app can start.
    # NOTE(review): the docstrings below are presumably what the agent sees as
    # each tool's description, so they are intentionally left unchanged.
    @tool
    def run_duckdb_query(query: str) -> str:
        """
        [PLACEHOLDER] Runs a read-only SQL query.
        Please create sql_tools.py to implement this.
        """
        # Pretend to answer schema-style questions; everything else is an error.
        lowered_query = query.lower()
        if "schema" in lowered_query or "describe" in lowered_query:
            return _PLACEHOLDER_SCHEMA
        return "Error: 'sql_tools.py' not found. This is a placeholder."

    @tool
    def get_table_schema(table_name: str = "positions") -> str:
        """
        [PLACEHOLDER] Returns the schema for the 'positions' table.
        Please create sql_tools.py to implement this.
        """
        # `table_name` is ignored: the placeholder only knows one fixed schema.
        return _PLACEHOLDER_SCHEMA
# --- Agent Tools ---
# These tools perform analysis on data *after* it has been fetched.
@tool
def calculate_summary_statistics_from_data(data_string: str, column: str) -> str:
    """
    Calculates summary statistics (mean, median, std, min, max, count) for a specific
    'column' from a 'data_string'.
    'data_string' should be the string output from the `run_duckdb_query` tool.
    Returns the statistics as a dict-formatted string, or a message starting with "Error".
    """
    # NOTE(review): this function is registered through the @tool decorator
    # above; its docstring presumably doubles as the tool description the agent
    # reads, so keep it short and LLM-oriented.
    try:
        table_text = data_string.strip()

        # Convert the whitespace-aligned table text back into a DataFrame.
        # `sep=r"\s+"` is the supported spelling of the deprecated
        # `delim_whitespace=True` keyword.
        data_df = pd.read_csv(StringIO(table_text), sep=r"\s+", header=0)

        # HACK: The string output might have an extra index column, so retry
        # treating the first column as an unnamed index before giving up.
        if column not in data_df.columns:
            data_df = pd.read_csv(StringIO(table_text), sep=r"\s+", header=0, index_col=0)
        if column not in data_df.columns:
            return f"Error: Column '{column}' not found in data."

        series = data_df[column]
        computed = {
            "mean": series.mean(),
            "median": series.median(),
            "std_dev": series.std(),
            "min": series.min(),
            "max": series.max(),
            "count": series.count(),
        }

        # pandas hands back NumPy scalars, whose repr under NumPy 2 is e.g.
        # `np.float64(1.5)`. Convert them to builtin numbers so the string
        # returned to the agent contains plain values only.
        stats = {"column": column}
        for stat_name, value in computed.items():
            stats[stat_name] = value.item() if hasattr(value, "item") else value
        return str(stats)
    except Exception as e:
        # Tools report failures as text so the agent can read them and react.
        return f"Error in calculate_summary_statistics: {e}. Data input was: '{data_string[:200]}...'"
@tool
def perform_arima_forecast_from_data(data_string: str, time_column: str, value_column: str, forecast_periods: int) -> str:
    """
    Performs an ARIMA(1,1,1) forecast on a 'data_string'.
    'data_string': The string output from `run_duckdb_query`.
    'time_column': The name of the date/time column in the data.
    'value_column': The name of the numerical column to forecast.
    'forecast_periods': The number of periods (e.g., days) to forecast. Must be at least 1.
    The data MUST be ordered by the time_column before being passed to this tool.
    """
    # NOTE(review): this function is registered through the @tool decorator
    # above; its docstring presumably doubles as the tool description the agent
    # reads, so keep it short and LLM-oriented.
    try:
        # A zero or negative horizon is meaningless; reject it before doing
        # any parsing or model fitting.
        if forecast_periods < 1:
            return "Error: forecast_periods must be a positive integer."

        table_text = data_string.strip()

        # Convert the whitespace-aligned table text back into a DataFrame.
        # `sep=r"\s+"` is the supported spelling of the deprecated
        # `delim_whitespace=True` keyword.
        data_df = pd.read_csv(StringIO(table_text), sep=r"\s+", header=0)

        # HACK: The string output might have an extra index column, so retry
        # treating the first column as an unnamed index.
        if time_column not in data_df.columns:
            data_df = pd.read_csv(StringIO(table_text), sep=r"\s+", header=0, index_col=0)
        if time_column not in data_df.columns:
            return f"Error: Time column '{time_column}' not found in data."
        if value_column not in data_df.columns:
            return f"Error: Value column '{value_column}' not found in data."
        if data_df.empty:
            return "Error: Query returned no data."

        # Prepare data for statsmodels: a datetime index at daily frequency.
        data_df[time_column] = pd.to_datetime(data_df[time_column])
        data_df = data_df.set_index(time_column)
        data_df = data_df.asfreq('D')  # Ensure daily frequency; missing days become NaN rows
        # Carry the last known value across the gaps `asfreq` introduced.
        # `.ffill()` replaces the deprecated `fillna(method='ffill')`.
        data_df[value_column] = data_df[value_column].ffill()

        model = sm.tsa.ARIMA(data_df[value_column], order=(1, 1, 1))
        results = model.fit()
        forecast = results.forecast(steps=forecast_periods)

        forecast_df = pd.DataFrame({
            'date': forecast.index.strftime('%Y-%m-%d'),
            'forecasted_value': forecast.values
        })
        return f"Forecast successful. Last historical value was {data_df[value_column].iloc[-1]:.2f}.\nForecast:\n{forecast_df.to_string()}"
    except Exception as e:
        # Tools report failures as text so the agent can read them and react.
        return f"Error in perform_arima_forecast: {e}. Data input was: '{data_string[:200]}...'"
# --- Main Agent and UI Setup ---
# Check for the GROQ_API_KEY in Hugging Face Space Secrets
if "GROQ_API_KEY" not in os.environ:
    # Without the key the agent cannot be built; serve a stub chat UI that
    # tells the user how to fix the configuration instead of crashing.
    print("GROQ_API_KEY not found in secrets!")

    def missing_key_error(message, history):
        """Reply to every chat message with the missing-secret instructions."""
        return "Error: `GROQ_API_KEY` is not set in this Space's Secrets. Please add it to use the app."

    gr.ChatInterface(
        missing_key_error,
        title="Agentic Portfolio Analyst",
        description="Error: GROQ_API_KEY secret is missing."
    ).launch()
else:
    print("GROQ_API_KEY found. Initializing agent...")

    # 1. Create the LLM client
    # NOTE(review): presumably ChatGroq picks GROQ_API_KEY up from the
    # environment, since it is not passed explicitly — confirm.
    llm = ChatGroq(model_name="llama-3.3-70b-versatile")

    # 2. Collect all our tools (imported and local)
    tools = [
        run_duckdb_query,
        get_table_schema,
        calculate_summary_statistics_from_data,
        perform_arima_forecast_from_data
    ]

    # 3. Create the Agent Prompt
    system_prompt = """
You are an expert portfolio analyst. You have access to SQL tools and analysis tools.
Your logic MUST follow these steps:
1. Use `get_table_schema` to understand the data.
2. Use `run_duckdb_query` to fetch the raw data you need.
3. If analysis (statistics or forecasting) is needed, take the string output
from `run_duckdb_query` and pass it *directly* to either
`calculate_summary_statistics_from_data` or `perform_arima_forecast_from_data`.
Example for forecasting:
1. Call `run_duckdb_query("SELECT report_date, SUM(market_value_usd) AS total_value FROM positions WHERE sector = 'Tech' GROUP BY report_date ORDER BY report_date")`.
2. Get the result string: " report_date total_value \n 2024-01-01 100000.0 \n 2024-01-02 100500.0 \n ..."
3. Call `perform_arima_forecast_from_data(data_string=" report_date total_value \n 2024-01-01 100000.0 \n ...", time_column="report_date", value_column="total_value", forecast_periods=30)`.
Answer the user's request based on the final tool output.
"""
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}"),
        ]
    )

    # 4. Create the Agent
    agent = create_tool_calling_agent(llm, tools, prompt)

    # 5. Create the Agent Executor
    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True
    )

    # 6. Define the function for Gradio
    def run_agent(message, history):
        """
        Run one chat turn through the agent and return its text answer.

        `message` is the new user message. `history` is the prior conversation
        as supplied by Gradio, accepted in either of its two formats: a list of
        (user, assistant) pairs, or a list of {"role", "content"} dicts.
        Any failure is returned as an error string rather than raised.
        """
        role_names = {"user": "human", "assistant": "ai"}
        try:
            chat_history = []
            for turn in history or []:
                if isinstance(turn, dict):
                    # "messages" format: one dict per message.
                    role = role_names.get(turn.get("role"))
                    content = turn.get("content")
                    # Only plain-text messages with a known role are forwarded.
                    if role is not None and isinstance(content, str):
                        chat_history.append((role, content))
                else:
                    # Legacy pair format: (user_message, assistant_message).
                    human_msg, ai_msg = turn
                    # Either side may be None (e.g. no reply yet); skip those.
                    if human_msg is not None:
                        chat_history.append(("human", human_msg))
                    if ai_msg is not None:
                        chat_history.append(("ai", ai_msg))

            response = agent_executor.invoke({
                "input": message,
                "chat_history": chat_history
            })
            return response["output"]
        except Exception as e:
            # Top-level UI boundary: show the problem in the chat window.
            return f"An error occurred: {e}"

    # 7. Launch the Gradio App
    gr.ChatInterface(
        run_agent,
        title="Agentic Portfolio Analyst",
        description="Ask me questions about your portfolio. (This app uses imported SQL tools).",
        examples=[
            "What is the schema of the positions table?",
            "What's the total market value by sector on the last available date?",
            "Give me summary statistics for the 'Tech' sector's market value from portfolio P-123. Use the 'market_value_usd' column for stats.",
            "What is the 30-day forecast for the total market value of portfolio P-123? Use 'total_value' for the forecast value column."
        ]
    ).launch()
|