Spaces:
Sleeping
Sleeping
| # /// script | |
| # requires-python = ">=3.12" | |
| # dependencies = [ | |
| # "ell-ai==0.0.14", | |
| # "marimo", | |
| # "openai==1.53.0", | |
| # "polars==1.12.0", | |
| # "altair==5.4.1", | |
| # ] | |
| # /// | |
import marimo

# marimo version that last wrote this notebook file.
__generated_with = "0.9.20"

# Notebook application object; `app.run()` at the bottom of the file executes it.
app = marimo.App(width="medium")
@app.cell
def __(mo):
    # Page heading, rendered as this cell's output.
    mo.md(r"""# Generative UI Chatbot""")
    return
@app.cell
def __(mo):
    # Editable dataset location. Its `.value` is what the loading cell passes to
    # `pl.read_csv`; the default is an `hf://` dataset path.
    _default_dataset = "hf://datasets/scikit-learn/Fish/Fish.csv"
    dataset_input = mo.ui.text(value=_default_dataset, full_width=True)
    return (dataset_input,)
@app.cell
def __(dataset_input, mo):
    # Intro text with the dataset text box embedded inline.
    mo.md(f"""
    This chatbot can answer questions about the following dataset: {dataset_input}
    """)
    return
@app.cell
def __(dataset_input, mo, pl):
    # Load the dataset named in the text box and report the outcome as this
    # cell's output. On failure, fall back to an empty frame so cells that
    # depend on `df` still run, and show the error as a danger callout.
    try:
        # Only the read itself is guarded; UI calls on the success path run in `else`.
        df = pl.read_csv(dataset_input.value)
    except Exception as e:  # UI boundary: any read failure is reported to the user
        df = pl.DataFrame()
        mo.output.replace(
            mo.md(f"""**Error loading dataset**:\n\n{e}""").callout(kind="danger")
        )
    else:
        mo.output.replace(
            mo.md(f"Loaded dataset with {len(df)} rows and {len(df.columns)} columns.")
        )
    return (df,)
@app.cell
def __():
    # Shared imports, made available to other cells via the return tuple.
    import os

    import marimo as mo
    import polars as pl

    return mo, os, pl
@app.cell
def __(mo, os):
    # Masked text box for the OpenAI key. It is pre-filled from the
    # OPENAI_API_KEY environment variable, or left empty when that is unset/empty.
    api_key_input = mo.ui.text(
        label="OpenAI API Key",
        kind="password",
        value=os.environ.get("OPENAI_API_KEY") or "",
    )
    return (api_key_input,)
@app.cell
def __(api_key_input):
    # Display the API key input as this cell's output.
    api_key_input
    return
@app.cell
def __(api_key_input, mo):
    from openai import Client

    # Stop here with a short notice until a key has been entered; only then
    # is the OpenAI client constructed.
    mo.stop(not api_key_input.value, mo.md("_Missing API key_"))
    client = Client(api_key=api_key_input.value)
    return Client, client
@app.cell
def __(df, mo):
    import ell

    # Both functions below are registered as ell tools so the chat cell can
    # offer them to the model. ell uses each docstring as the tool description
    # sent to the model, so treat the docstrings as prompt text.

    @ell.tool()
    def chart_data(x_encoding: str, y_encoding: str, color: str):
        """Generate an altair chart"""
        import altair as alt

        # Scatter plot of the loaded dataset, using the column names the model picked.
        return (
            alt.Chart(df)
            .mark_circle()
            .encode(x=x_encoding, y=y_encoding, color=color)
            .properties(width=500)
        )

    @ell.tool()
    def filter_dataset(sql_query: str):
        """
        Filter a polars dataframe using SQL. Please only use fields from the schema.
        When referring to the table in SQL, call it 'data'.
        """
        try:
            filtered = df.sql(sql_query, table_name="data")
        except Exception as e:  # model-written SQL may be invalid or name unknown columns
            # Surface the failure in the chat instead of crashing the reply,
            # matching how the dataset-loading cell reports errors.
            return mo.md(f"**Error running query**:\n\n{e}").callout(kind="danger")
        # Show the result as a read-only table labelled with the SQL that produced it.
        return mo.ui.table(
            filtered,
            label=f"```sql\n{sql_query}\n```",
            selection=None,
            show_column_summaries=False,
        )

    return chart_data, ell, filter_dataset
@app.cell
def __(chart_data, client, df, ell, filter_dataset, mo):
    # LLM program: ell sends the docstring as the system prompt and the returned
    # string as the user message, offering `chart_data` / `filter_dataset` as
    # tools. Without this decorator `analyze_dataset` returns a plain str and
    # `response.tool_calls` below fails with AttributeError.
    # NOTE(review): the model name is our choice — any OpenAI chat model that
    # supports tool calling should work; confirm before deploying.
    @ell.complex(model="gpt-4o", tools=[chart_data, filter_dataset], client=client)
    def analyze_dataset(prompt: str) -> str:
        """You are a data scientist that can analyze a dataset"""
        return f"I have a dataset with schema: {df.schema}. \n{prompt}"

    def my_model(messages):
        # Chat handler for mo.ui.chat. The full message history is interpolated
        # into the prompt via the f-string above (presumably via its str() form — verify).
        response = analyze_dataset(messages)
        if response.tool_calls:
            # Only the first requested tool is run; its result (a chart or a
            # table) becomes the chat reply.
            return response.tool_calls[0]()
        return response.text

    mo.ui.chat(
        my_model,
        prompts=[
            "Can you chart two columns of your choosing?",
            "Can you find the min, max of all numeric fields?",
            "What is the sum of {{column}}?",
        ],
    )
    return analyze_dataset, my_model
if __name__ == "__main__":
    # Executing the file directly runs the notebook app.
    app.run()