# Scraped Hugging Face file-view header (not code): author awacke1,
# commit 0ad5b45 "Update app.py" (verified), file size 10.1 kB;
# viewer links: raw / history / blame.
# app.py
import io
import tempfile
import traceback
import warnings
from pathlib import Path

import dask.dataframe as dd
import gradio as gr
import pandas as pd
import polars as pl
import requests
from datasets import load_dataset, Image
from huggingface_hub import get_token
from mlcroissant import Dataset as CroissantDataset
# 🤫 Let's ignore those pesky warnings, shall we?
# NOTE(review): this silences *every* warning for the whole process (including
# deprecation warnings from pandas / gradio / datasets); consider narrowing it
# with `category=` / `module=` filters.
warnings.filterwarnings("ignore")
# --- ⚙️ Configuration & Constants ---
# One entry per UI tab, keyed by the tab id (also shown, capitalized, as the tab title):
#   name      - Hub dataset repo id; used in the dataset-viewer API URL, the
#               hf:// paths read by pandas, and load_dataset().
#   emoji     - decoration for the tab title and heading.
#   methods   - Radio choices for the tab; the first one is the default.
#               fetch_data matches these labels by substring and only handles
#               "API", "Pandas" and "Datasets"; the remaining labels (Dask,
#               Polars, Croissant) have no loader there.
#   is_public - False for gated datasets: auth headers are sent on API calls
#               and the tab shows a login note.
DATASET_CONFIG = {
    "caselaw": {
        "name": "common-pile/caselaw_access_project", "emoji": "βš–οΈ",
        "methods": ["πŸ’¨ API (requests)", "🧊 Dask", "πŸ₯ Croissant"], "is_public": True,
    },
    "prompts": {
        "name": "fka/awesome-chatgpt-prompts", "emoji": "πŸ€–",
        "methods": ["🐼 Pandas", "πŸ’¨ API (requests)", "πŸ₯ Croissant"], "is_public": True,
    },
    "finance": {
        "name": "snorkelai/agent-finance-reasoning", "emoji": "πŸ’°",
        "methods": ["🐼 Pandas", "🧊 Polars", "πŸ’¨ API (requests)", "πŸ₯ Croissant"], "is_public": False,
    },
    "medical": {
        "name": "FreedomIntelligence/medical-o1-reasoning-SFT", "emoji": "🩺",
        "methods": ["🐼 Pandas", "🧊 Polars", "πŸ’¨ API (requests)", "πŸ₯ Croissant"], "is_public": False,
    },
    "inscene": {
        "name": "peteromallet/InScene-Dataset", "emoji": "πŸ–ΌοΈ",
        "methods": ["πŸ€— Datasets", "🐼 Pandas", "🧊 Polars", "πŸ’¨ API (requests)", "πŸ₯ Croissant"], "is_public": False,
    },
}
# --- Helpers & Utility Functions ---
def get_auth_headers():
    """Return HTTP headers that authenticate requests to the Hugging Face Hub.

    The token comes from ``huggingface_hub.get_token()`` (for example one saved
    by ``huggingface-cli login``). Without a token the result is an empty dict,
    so the request is sent anonymously.
    """
    hf_token = get_token()
    if not hf_token:
        return {}
    return {"Authorization": f"Bearer {hf_token}"}
def dataframe_to_outputs(df: pd.DataFrame):
    """Render a results DataFrame into the four export widgets of a dataset tab.

    Args:
        df: The (already filtered) results to export.

    Returns:
        A 4-tuple ``(markdown, csv_path, xlsx_path, tab_delimited)``:
        a Markdown table, the path of a written CSV file, the path of a written
        XLSX file (``None`` when the optional ``openpyxl`` engine is missing),
        and a tab-separated string for copy/paste. For an empty ``df`` both
        file slots are ``None`` and the text slots hold placeholder messages.
    """
    if df.empty:
        return "No results found. 🀷", None, None, "No results to copy."

    df_str = df.astype(str)
    try:
        markdown_output = df_str.to_markdown(index=False)
    except ImportError:
        # DataFrame.to_markdown needs the optional `tabulate` package; show a
        # fixed-width table instead of failing the whole fetch.
        markdown_output = f"```\n{df_str.to_string(index=False)}\n```"

    # gr.File outputs take file paths, so write the exports to a fresh
    # directory per call (concurrent users never overwrite each other) while
    # keeping the user-facing file names. The directory is not removed here.
    out_dir = Path(tempfile.mkdtemp(prefix="hf_explorer_"))
    csv_path = out_dir / "results.csv"
    df.to_csv(csv_path, index=False)

    xlsx_path = out_dir / "results.xlsx"
    try:
        try:
            df.to_excel(xlsx_path, index=False, engine="openpyxl")
        except ValueError:
            # openpyxl rejects nested cell values (dicts, lists, image
            # objects); export their string form instead.
            df_str.to_excel(xlsx_path, index=False, engine="openpyxl")
        xlsx_file = str(xlsx_path)
    except ImportError:
        # openpyxl is an optional pandas dependency; skip the XLSX export.
        xlsx_file = None

    tab_delimited_output = df.to_csv(sep="\t", index=False)
    return markdown_output, str(csv_path), xlsx_file, tab_delimited_output
def handle_error(e: Exception):
    """Convert an exception into the 8 output values of a dataset tab.

    Prints the full traceback to stdout for debugging. When the message looks
    like an authentication / gating failure, a login tip is appended.

    Args:
        e: The exception caught by the fetch handler. Call this from inside the
            ``except`` block so ``traceback.format_exc()`` sees the traceback.

    Returns:
        An 8-tuple in the order of the tab's outputs: empty DataFrame, cleared
        gallery, error Markdown, empty Markdown/CSV/XLSX/copy slots, and a
        Python snippet whose comment lines name the error.
    """
    message = str(e)
    error_message = f"🚨 An error occurred: {message}\n\n"
    auth_tip = "πŸ”‘ For gated datasets, did you log in? Try `huggingface-cli login` in your terminal."
    full_trace = traceback.format_exc()
    print(full_trace)
    # Hub auth failures surface as HTTP 401/403 and usually mention "gated"
    # in lower case, so match case-insensitively.
    lowered = message.lower()
    if any(marker in lowered for marker in ("401", "403", "gated")):
        error_message += auth_tip
    # Prefix every line so multi-line messages stay valid Python comments.
    commented = "\n".join(f"# {line}" for line in (message.splitlines() or [""]))
    # Return a tuple of 8 to match the outputs
    return (
        pd.DataFrame(), gr.Gallery(None), f"### 🚨 Error\n```\n{error_message}\n```",
        "", None, None, "", f"```python\n# 🚨 Error during execution:\n{commented}\n```",
    )
def search_dataframe(df: pd.DataFrame, query: str):
    """Case-insensitive plain-text search across every string-like column.

    Args:
        df: Rows to search.
        query: Text to look for. It is matched literally (not as a regular
            expression), so input such as ``"C++"`` or ``"(a"`` is safe.
            An empty/``None`` query means "no filter".

    Returns:
        The first 100 rows of ``df`` when ``query`` is empty. Otherwise, the
        rows in which any ``object``/``string`` column contains ``query``, with
        the original index preserved. An empty DataFrame when ``df`` has no
        string-like columns.
    """
    if not query:
        return df.head(100)
    string_cols = df.select_dtypes(include=["object", "string"]).columns
    if string_cols.empty:
        return pd.DataFrame()
    # Build the mask on df's own index: a fresh RangeIndex would misalign with
    # (and break boolean indexing on) frames whose index is not 0..n-1.
    mask = pd.Series(False, index=df.index)
    for col in string_cols:
        mask |= df[col].astype(str).str.contains(query, case=False, na=False, regex=False)
    return df[mask]
# --- 🎣 Data Fetching & Processing Functions ---
# fetch_data always yields 8-tuples: one value per output component of a tab.
_ROWS_API_URL = "https://datasets-server.huggingface.co/rows"
_API_PAGE_SIZE = 100  # rows requested per /rows call
_API_MAX_PAGES = 5  # pages scanned when a search query is given
_API_TIMEOUT_S = 30  # seconds; requests.get has no timeout by default
_DATASETS_SAMPLE_ROWS = 1000  # rows taken from the streamed split


def _fetch_api_rows(repo_id: str, offset: int, headers: dict) -> list:
    """Fetch one page of rows for ``repo_id`` from the dataset viewer ``/rows`` API.

    Args:
        repo_id: Hub dataset id, e.g. ``"fka/awesome-chatgpt-prompts"``.
        offset: Index of the first row to return.
        headers: Extra HTTP headers (auth for gated datasets, or ``{}``).

    Returns:
        The row payloads (the ``"row"`` value of each returned entry), or an
        empty list when the response contains no rows.

    Raises:
        requests.HTTPError: for non-2xx responses (e.g. 401 on gated data).
    """
    response = requests.get(
        _ROWS_API_URL,
        # `params` lets requests URL-encode the repo id and other values.
        params={
            "dataset": repo_id,
            "config": "default",
            "split": "train",
            "offset": offset,
            "length": _API_PAGE_SIZE,
        },
        headers=headers,
        timeout=_API_TIMEOUT_S,
    )
    response.raise_for_status()
    return [entry["row"] for entry in response.json().get("rows") or []]


def _gallery_items(df: pd.DataFrame) -> list:
    """Build ``(image, caption)`` pairs for ``gr.Gallery`` from ``df``.

    Reads the ``image`` column and takes captions from ``text`` when present.
    Dict cells (as returned by the dataset viewer API) are unwrapped to their
    ``"src"`` (or ``"path"``) entry. Other non-missing values, such as decoded
    images from ``load_dataset``, are passed to Gradio unchanged.
    NOTE(review): assumes the columns are named ``image`` / ``text``; confirm
    against the InScene dataset schema.
    """
    if "image" not in df.columns:
        return []
    captions = df["text"] if "text" in df.columns else [""] * len(df)
    items = []
    for image, caption in zip(df["image"], captions):
        if isinstance(image, dict):
            image = image.get("src") or image.get("path")
        if pd.api.types.is_scalar(image) and pd.isna(image):
            continue  # missing cell: nothing to show
        items.append((image, caption if isinstance(caption, str) else ""))
    return items


def _gallery_update(df: pd.DataFrame):
    """Return a ``gr.Gallery`` update showing the images found in ``df``."""
    return gr.Gallery(_gallery_items(df), label="πŸ–ΌοΈ Image Results", height=400)


def _read_with_pandas(repo_id: str) -> pd.DataFrame:
    """Load ``repo_id``'s data through pandas' ``hf://`` filesystem support.

    Repos use different file layouts, so known locations are tried in order
    and the first readable one wins.

    Raises:
        FileNotFoundError: when no candidate could be read. The message
            includes the last underlying error (so auth hints such as "401"
            stay visible), and that error is chained as ``__cause__``.
    """
    base = f"hf://datasets/{repo_id}/"
    if repo_id == "fka/awesome-chatgpt-prompts":
        return pd.read_csv(f"{base}prompts.csv")
    candidates = (
        (pd.read_parquet, f"{base}data/train-00000-of-00001.parquet"),
        (pd.read_parquet, f"{base}train.parquet"),
        (pd.read_json, f"{base}medical_o1_sft.json"),
    )
    last_error = None
    for reader, path in candidates:
        try:
            return reader(path)
        except Exception as err:  # missing file, auth or parse error: try the next layout
            last_error = err
    raise FileNotFoundError(
        f"No known data file could be read for `{repo_id}` with pandas "
        f"(last error: {last_error})"
    ) from last_error


def _search_via_api(dataset_key: str, config: dict, query: str, outputs: list):
    """Stream search results page by page from the dataset viewer API.

    Mutates ``outputs`` (the 8 UI slots) in place and yields a snapshot tuple
    after each status change, so the UI updates while later pages load.
    Without a ``query`` only the first page is fetched, as a sample.
    """
    repo_id = config["name"]
    # Auth is only sent for gated datasets; computed once instead of per page.
    headers = {} if config["is_public"] else get_auth_headers()
    max_pages = _API_MAX_PAGES
    if not query:
        max_pages = 1
        outputs[2] = "⏳ No search term. Fetching first 100 records as a sample..."
        yield tuple(outputs)
    matches = []
    results = pd.DataFrame()
    for page in range(max_pages):
        if query:
            outputs[2] = f"⏳ Searching page {page + 1}..."
            yield tuple(outputs)
        rows = _fetch_api_rows(repo_id, page * _API_PAGE_SIZE, headers)
        if not rows:
            outputs[2] = "🏁 No more data to search."
            yield tuple(outputs)
            break
        # Each row is a {column: value} dict. Nested values (e.g. image dicts)
        # are kept intact so the gallery can read them.
        found = search_dataframe(pd.DataFrame(rows), query)
        if found.empty:
            continue
        matches.append(found)
        results = pd.concat(matches, ignore_index=True)
        outputs[0] = results
        outputs[3], outputs[4], outputs[5], outputs[6] = dataframe_to_outputs(results)
        outputs[2] = f"βœ… Found **{len(results)}** results so far..."
        if dataset_key == "inscene":
            outputs[1] = _gallery_update(results)
        yield tuple(outputs)
    outputs[2] = f"🏁 Search complete. Found a total of **{len(results)}** results."
    yield tuple(outputs)


def fetch_data(dataset_key: str, access_method: str, query: str):
    """Run a search for one dataset tab, streaming UI updates.

    This is a generator used as the Gradio click handler. Every yielded value
    is an 8-tuple matching the tab's outputs: (results DataFrame, gallery,
    status Markdown, Markdown table, CSV file, XLSX file, tab-delimited text,
    code snippet).

    Args:
        dataset_key: Key into ``DATASET_CONFIG``.
        access_method: Radio label, dispatched on the substrings "API",
            "Pandas" and "Datasets". Any other method is reported to the user
            as not implemented.
        query: Text to search for; empty means "show a sample".

    Any exception is turned into an error update by ``handle_error``.
    """
    # Initial state for all 8 output components.
    outputs = [pd.DataFrame(), None, "🏁 Ready.", "", None, None, "", ""]
    try:
        config = DATASET_CONFIG[dataset_key]
        repo_id = config["name"]
        if "API" in access_method:
            yield from _search_via_api(dataset_key, config, query, outputs)
            return

        outputs[2] = f"⏳ Loading data via `{access_method}`..."
        yield tuple(outputs)
        if "Pandas" in access_method:
            df = _read_with_pandas(repo_id)
        elif "Datasets" in access_method:
            # Streaming avoids downloading the whole split just to sample it.
            stream = load_dataset(repo_id, split="train", streaming=True)
            df = pd.DataFrame(list(stream.take(_DATASETS_SAMPLE_ROWS)))
        else:
            # Dask / Polars / Croissant are offered in the UI but have no
            # loader yet: say so instead of reporting "no results".
            raise NotImplementedError(
                f"Access method `{access_method}` is not implemented yet for "
                f"`{repo_id}`. Please choose another method."
            )

        outputs[2] = "πŸ” Searching loaded data..."
        yield tuple(outputs)
        final_df = search_dataframe(df, query)
        outputs[0] = final_df
        outputs[3], outputs[4], outputs[5], outputs[6] = dataframe_to_outputs(final_df)
        outputs[2] = f"🏁 Search complete. Found **{len(final_df)}** results."
        if dataset_key == "inscene" and not final_df.empty:
            outputs[1] = _gallery_update(final_df)
        yield tuple(outputs)
    except Exception as e:
        yield handle_error(e)
# --- πŸ–ΌοΈ UI Generation ---
def create_dataset_tab(dataset_key: str):
config = DATASET_CONFIG[dataset_key]
with gr.Tab(f"{config['emoji']} {dataset_key.capitalize()}"):
gr.Markdown(f"## {config['emoji']} Query the `{config['name']}` Dataset")
if not config['is_public']:
gr.Markdown("**Note:** This is a gated dataset. Please log in via `huggingface-cli login` in your terminal first.")
with gr.Row():
access_method = gr.Radio(config['methods'], label="πŸ”‘ Access Method", value=config['methods'][0])
query = gr.Textbox(label="πŸ” Search Query", placeholder="Enter any text to search, or leave blank for samples...")
fetch_button = gr.Button("πŸš€ Go Fetch!")
status_output = gr.Markdown("🏁 Ready to search.")
df_output = gr.DataFrame(label="πŸ“Š Results DataFrame", interactive=False, wrap=True)
gallery_output = gr.Gallery(visible=(dataset_key == 'inscene'), label="πŸ–ΌοΈ Image Results")
with gr.Accordion("πŸ“‚ View/Export Full Results", open=False):
markdown_output = gr.Markdown(label="πŸ“ Markdown View")
with gr.Row():
csv_output = gr.File(label="⬇️ Download CSV")
xlsx_output = gr.File(label="⬇️ Download XLSX")
copy_output = gr.Code(label="πŸ“‹ Copy-Paste (Tab-Delimited)")
code_output = gr.Code(label="πŸ’» Python Code Snippet", language="python")
fetch_button.click(
fn=fetch_data,
inputs=[gr.State(dataset_key), access_method, query],
outputs=[
df_output, gallery_output, status_output, markdown_output,
csv_output, xlsx_output, copy_output, code_output
]
)
# --- 🚀 Main App ---
# Page layout: a title, an intro paragraph, then one tab per DATASET_CONFIG entry.
with gr.Blocks(theme=gr.themes.Soft(), title="Hugging Face Dataset Explorer") as demo:
    gr.Markdown("# πŸ€— Hugging Face Dataset Explorer")
    gr.Markdown(
        "Select a dataset, choose an access method, and type a query. "
        "The app now **streams results** for the API method and performs a **universal text search** on all datasets!"
    )
    with gr.Tabs():
        # Tabs appear in DATASET_CONFIG's insertion order.
        for key in DATASET_CONFIG.keys():
            create_dataset_tab(key)
if __name__ == "__main__":
    # NOTE(review): debug=True is a development setting; confirm it is wanted
    # in the deployed Space.
    demo.launch(debug=True)