#!/usr/bin/env python3
"""
ToGMAL Difficulty Assessment Demo
=================================
Gradio demo for the vector database-based prompt difficulty assessment.
Shows real-time difficulty scores and recommendations.
"""
import gradio as gr
import json
from pathlib import Path
from benchmark_vector_db import BenchmarkVectorDB
import logging
import os
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize the vector database with error handling
db_path = Path("./data/benchmark_vector_db")
db = None
try:
logger.info("Initializing BenchmarkVectorDB...")
db = BenchmarkVectorDB(
db_path=db_path,
embedding_model="all-MiniLM-L6-v2"
)
logger.info("βœ“ BenchmarkVectorDB initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize BenchmarkVectorDB: {e}")
logger.error("This might be due to network issues downloading the embedding model.")
logger.error("The app will attempt to initialize on first use.")
# Will try again on first query
db = None
# Build database if not exists (first launch on Hugging Face)
# Start with a manageable size to avoid build timeout
if db is not None:
current_count = db.collection.count()
else:
logger.warning("Database not initialized - will retry on first query")
current_count = 0
if db is not None and current_count == 0:
logger.info("Database is empty - building database...")
logger.info("Building 5K questions to stay within build time limits.")
try:
from datasets import load_dataset
from benchmark_vector_db import BenchmarkQuestion
# Load MMLU-Pro test split (sample 5K for fast build)
logger.info("Loading MMLU-Pro test split (5K sample)...")
test_dataset = load_dataset("TIGER-Lab/MMLU-Pro", split="test")
logger.info(f" Dataset has {len(test_dataset)} questions total")
# Sample 5000 questions for fast initial build
import random
total_questions = len(test_dataset)
if total_questions > 5000:
indices = random.sample(range(total_questions), 5000)
test_dataset = test_dataset.select(indices)
logger.info(f" Sampled 5000 questions for initial build")
all_questions = []
# Process questions
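        # No per-model accuracy stats are attached at this stage, so every
        # question below gets a placeholder success_rate of 0.45 (difficulty
        # 0.55, labelled "Hard") until real benchmark results are indexed.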
for idx, item in enumerate(test_dataset):
question = BenchmarkQuestion(
question_id=f"mmlu_pro_test_{idx}",
source_benchmark="MMLU_Pro",
domain=item.get('category', 'unknown').lower(),
question_text=item['question'],
correct_answer=item['answer'],
choices=item.get('options', []),
success_rate=0.45,
difficulty_score=0.55,
difficulty_label="Hard",
num_models_tested=0
)
all_questions.append(question)
logger.info(f"Indexing {len(all_questions)} questions...")
# Index in batches of 1000
batch_size = 1000
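        # Indexing in 1,000-question batches keeps each embedding/insert call
        # small; presumably this bounds memory use during the startup build.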
for i in range(0, len(all_questions), batch_size):
batch = all_questions[i:i + batch_size]
batch_num = i // batch_size + 1
total_batches = (len(all_questions) + batch_size - 1) // batch_size
logger.info(f" Batch {batch_num}/{total_batches}...")
db.index_questions(batch)
logger.info(f"βœ“ Database build complete! Indexed {len(all_questions)} questions")
logger.info("Note: This is a 5K subset. Full 26K database available locally.")
except Exception as e:
logger.error(f"Failed to build database: {e}")
logger.info("Falling back to minimal build...")
db.build_database(
load_gpqa=False,
load_mmlu_pro=True,
load_math=False,
max_samples_per_dataset=1000
)
elif db is not None:
    logger.info(f"✓ Loaded existing database with {current_count:,} questions")
def analyze_prompt(prompt: str, k: int = 5) -> str:
"""Analyze a prompt and return difficulty assessment."""
global db
if not prompt.strip():
return "Please enter a prompt to analyze."
# Retry DB initialization if it failed before
if db is None:
try:
logger.info("Retrying database initialization...")
db = BenchmarkVectorDB(
db_path=db_path,
embedding_model="all-MiniLM-L6-v2"
)
logger.info("βœ“ Database initialized successfully on retry")
except Exception as e:
return f"""### ❌ Database Initialization Error
The vector database could not be initialized due to a network error downloading the embedding model.
**Error:** {str(e)}
**This is a temporary HuggingFace Spaces issue.** Please:
1. Wait a few minutes for the model to download
2. Try refreshing the page
3. Contact support if the issue persists
The embedding model `sentence-transformers/all-MiniLM-L6-v2` is being downloaded from HuggingFace Hub.
"""
try:
result = db.query_similar_questions(prompt, k=k)
# Format results
output = []
output.append(f"## 🎯 Difficulty Assessment\n")
output.append(f"**Risk Level**: {result['risk_level']}")
output.append(f"**Success Rate**: {result['weighted_success_rate']:.1%}")
output.append(f"**Avg Similarity**: {result['avg_similarity']:.3f}")
output.append("")
output.append(f"**Recommendation**: {result['recommendation']}")
output.append("")
output.append(f"## πŸ” Similar Benchmark Questions\n")
for i, q in enumerate(result['similar_questions'], 1):
output.append(f"{i}. **{q['question_text'][:100]}...**")
output.append(f" - Source: {q['source']} ({q['domain']})")
output.append(f" - Success Rate: {q['success_rate']:.1%}")
output.append(f" - Similarity: {q['similarity']:.3f}")
output.append("")
total_questions = db.collection.count()
output.append(f"*Analyzed using {k} most similar questions from {total_questions:,} benchmark questions*")
return "\n".join(output)
except Exception as e:
return f"Error analyzing prompt: {str(e)}"
def expand_database(batch_size: int = 5000) -> str:
"""Expand the database by adding another batch of questions from multiple sources."""
try:
from datasets import load_dataset
from benchmark_vector_db import BenchmarkQuestion
import random
current_count = db.collection.count()
# Load from ALL available sources to reach 32K+
# We'll build a pool of all questions and track which ones we've indexed
all_questions_pool = []
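        # The pool is rebuilt from scratch on every click and then sliced by
        # the current collection count, so each expansion indexes the next
        # contiguous chunk of the pool. Note: HellaSwag is randomly sampled
        # below without a fixed seed, so the pool can differ slightly
        # between clicks.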
logger.info("Loading all available benchmark datasets...")
# Source 1: MMLU-Pro test split (12,032 questions)
try:
logger.info(" Loading MMLU-Pro test...")
mmlu_pro_test = load_dataset("TIGER-Lab/MMLU-Pro", split="test")
for idx, item in enumerate(mmlu_pro_test):
all_questions_pool.append({
'id': f"mmlu_pro_test_{idx}",
'source': 'MMLU_Pro',
'domain': item.get('category', 'unknown').lower(),
'question': item['question'],
'answer': item['answer'],
'choices': item.get('options', []),
'success_rate': 0.45
})
logger.info(f" Added {len(mmlu_pro_test)} MMLU-Pro test questions")
except Exception as e:
logger.warning(f" Could not load MMLU-Pro test: {e}")
# Source 2: MMLU-Pro validation split (70 questions)
try:
logger.info(" Loading MMLU-Pro validation...")
mmlu_pro_val = load_dataset("TIGER-Lab/MMLU-Pro", split="validation")
for idx, item in enumerate(mmlu_pro_val):
all_questions_pool.append({
'id': f"mmlu_pro_val_{idx}",
'source': 'MMLU_Pro',
'domain': item.get('category', 'unknown').lower(),
'question': item['question'],
'answer': item['answer'],
'choices': item.get('options', []),
'success_rate': 0.45
})
logger.info(f" Added {len(mmlu_pro_val)} MMLU-Pro validation questions")
except Exception as e:
logger.warning(f" Could not load MMLU-Pro validation: {e}")
# Source 3: MMLU (original - 14,042 questions for cross-domain coverage)
try:
logger.info(" Loading MMLU (original)...")
# MMLU has multiple subjects, we'll load the test split
# Using the 'all' configuration to get all subjects
mmlu_dataset = load_dataset("cais/mmlu", "all", split="test")
for idx, item in enumerate(mmlu_dataset):
all_questions_pool.append({
'id': f"mmlu_{idx}",
'source': 'MMLU',
'domain': item.get('subject', 'cross_domain').lower(),
'question': item['question'],
'answer': str(item['answer']),
'choices': item.get('choices', []),
'success_rate': 0.65 # MMLU is easier than MMLU-Pro
})
logger.info(f" Added {len(mmlu_dataset)} MMLU questions")
except Exception as e:
logger.warning(f" Could not load MMLU: {e}")
# Source 4: ARC-Challenge - Science reasoning
try:
logger.info(" Loading ARC-Challenge...")
arc_dataset = load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test")
for idx, item in enumerate(arc_dataset):
all_questions_pool.append({
'id': f"arc_challenge_{idx}",
'source': 'ARC-Challenge',
'domain': 'science',
'question': item['question'],
'answer': item['answerKey'],
'choices': item['choices']['text'] if 'choices' in item else [],
'success_rate': 0.50
})
logger.info(f" Added {len(arc_dataset)} ARC-Challenge questions")
except Exception as e:
logger.warning(f" Could not load ARC-Challenge: {e}")
# Source 5: HellaSwag - Commonsense NLI (sample 2K from 10K)
try:
logger.info(" Loading HellaSwag...")
hellaswag_dataset = load_dataset("Rowan/hellaswag", split="validation")
# Sample to 2000 to manage size
if len(hellaswag_dataset) > 2000:
indices = random.sample(range(len(hellaswag_dataset)), 2000)
hellaswag_dataset = hellaswag_dataset.select(indices)
for idx, item in enumerate(hellaswag_dataset):
all_questions_pool.append({
'id': f"hellaswag_{idx}",
'source': 'HellaSwag',
'domain': 'commonsense',
'question': item['ctx'],
'answer': str(item['label']),
'choices': item['endings'] if 'endings' in item else [],
'success_rate': 0.65
})
logger.info(f" Added {len(hellaswag_dataset)} HellaSwag questions")
except Exception as e:
logger.warning(f" Could not load HellaSwag: {e}")
# Source 6: GSM8K - Math word problems
try:
logger.info(" Loading GSM8K...")
gsm8k_dataset = load_dataset("openai/gsm8k", "main", split="test")
for idx, item in enumerate(gsm8k_dataset):
all_questions_pool.append({
'id': f"gsm8k_{idx}",
'source': 'GSM8K',
'domain': 'math_word_problems',
'question': item['question'],
'answer': item['answer'],
'choices': None,
'success_rate': 0.55
})
logger.info(f" Added {len(gsm8k_dataset)} GSM8K questions")
except Exception as e:
logger.warning(f" Could not load GSM8K: {e}")
# Source 7: TruthfulQA - Truthfulness detection
try:
logger.info(" Loading TruthfulQA...")
truthfulqa_dataset = load_dataset("truthful_qa", "generation", split="validation")
for idx, item in enumerate(truthfulqa_dataset):
all_questions_pool.append({
'id': f"truthfulqa_{idx}",
'source': 'TruthfulQA',
'domain': 'truthfulness',
'question': item['question'],
'answer': item['best_answer'],
'choices': None,
'success_rate': 0.35
})
logger.info(f" Added {len(truthfulqa_dataset)} TruthfulQA questions")
except Exception as e:
logger.warning(f" Could not load TruthfulQA: {e}")
# Source 8: Winogrande - Commonsense reasoning
try:
logger.info(" Loading Winogrande...")
winogrande_dataset = load_dataset("winogrande", "winogrande_xl", split="validation")
for idx, item in enumerate(winogrande_dataset):
all_questions_pool.append({
'id': f"winogrande_{idx}",
'source': 'Winogrande',
'domain': 'commonsense_reasoning',
'question': item['sentence'],
'answer': item['answer'],
'choices': [item['option1'], item['option2']],
'success_rate': 0.70
})
logger.info(f" Added {len(winogrande_dataset)} Winogrande questions")
except Exception as e:
logger.warning(f" Could not load Winogrande: {e}")
total_available = len(all_questions_pool)
logger.info(f"Total questions available: {total_available:,}")
if current_count >= total_available:
return f"βœ… Database is complete! All {total_available:,} questions already indexed.\n\nπŸ“Š **20 domains** across **7 benchmark sources**!"
# Get next batch (skip ones we've already indexed)
start_idx = current_count
end_idx = min(start_idx + batch_size, total_available)
batch_data = all_questions_pool[start_idx:end_idx]
# Convert to BenchmarkQuestion objects
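        # Difficulty is derived from the assumed per-source success rate:
        # difficulty_score = 1 - success_rate, and anything under a 50%
        # success rate is labelled "Hard", otherwise "Moderate".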
batch_questions = []
for q_data in batch_data:
question = BenchmarkQuestion(
question_id=q_data['id'],
source_benchmark=q_data['source'],
domain=q_data['domain'],
question_text=q_data['question'],
correct_answer=q_data['answer'],
choices=q_data.get('choices'),
success_rate=q_data['success_rate'],
difficulty_score=1.0 - q_data['success_rate'],
difficulty_label="Hard" if q_data['success_rate'] < 0.5 else "Moderate",
num_models_tested=0
)
batch_questions.append(question)
logger.info(f"Indexing {len(batch_questions)} new questions...")
db.index_questions(batch_questions)
new_count = db.collection.count()
still_remaining = total_available - new_count
result = f"βœ… Successfully added {len(batch_questions)} questions!\n\n"
result += f"**Database Stats:**\n"
result += f"- Total Questions: {new_count:,}\n"
result += f"- Just Added: {len(batch_questions)}\n"
result += f"- Total Available: {total_available:,}\n"
result += f"- Remaining: {still_remaining:,}\n\n"
if still_remaining > 0:
result += f"πŸ’‘ Click 'Expand Database' again to add up to {min(batch_size, still_remaining):,} more questions.\n"
result += f"πŸ“Š Progress: {(new_count/total_available*100):.1f}% complete"
else:
result += f"πŸŽ‰ Database is now complete with all {total_available:,} questions!\n\n"
result += f"πŸ“š **Sources:** MMLU, MMLU-Pro, ARC-Challenge, HellaSwag, GSM8K, TruthfulQA, Winogrande\n"
result += f"🌐 **Domains:** 20+ including science, math, truthfulness, commonsense, and more!"
return result
except Exception as e:
logger.error(f"Expansion failed: {e}")
return f"❌ Error expanding database: {str(e)}"
def get_database_info() -> str:
"""Get current database statistics."""
global db
# Check if database is initialized
if db is None:
return """### ⚠️ Database Not Initialized
**Status:** Waiting for initialization
The vector database is not yet ready. This can happen when:
1. The embedding model is still downloading
2. Network issues prevented model download during startup
**To initialize:**
1. Try entering a prompt below and clicking "Analyze Difficulty"
2. This will trigger the database initialization
3. Wait a few moments for the model to download
4. Then refresh this page
Or wait a few minutes and click the "🔄 Refresh Stats" button.
"""
try:
current_count = db.collection.count()
# Total available across all sources
# MMLU: ~14,042 + MMLU-Pro: 12,102 + ARC: 1,172 + HellaSwag: 2,000
# + GSM8K: 1,319 + TruthfulQA: 817 + Winogrande: 1,267 = ~32,719 total
total_available = 32719
remaining = total_available - current_count
progress_pct = (current_count / total_available * 100) if total_available > 0 else 0
info = f"### πŸ“Š Database Status\n\n"
info += f"**Current Size:** {current_count:,} questions\n"
info += f"**Total Available:** {total_available:,} questions\n"
info += f"**Progress:** {progress_pct:.1f}% complete\n"
info += f"**Remaining:** {max(0, remaining):,} questions\n\n"
if remaining > 0:
info += f"πŸ’‘ Click 'Expand Database' to add 5,000 more questions (~2-3 min per batch)\n\n"
clicks_needed = (remaining + 4999) // 5000 # Round up
info += f"πŸ“ˆ ~{clicks_needed} more clicks to reach full 32K+ dataset"
else:
info += f"πŸŽ‰ Database is complete with all available questions!\n\n"
info += f"**Sources:** MMLU, MMLU-Pro, ARC-Challenge, HellaSwag, GSM8K, TruthfulQA, Winogrande\n"
info += f"**Domains:** 20+ including truthfulness, commonsense, math word problems, science, and more!"
return info
except Exception as e:
return f"Error getting database info: {str(e)}"
# Create Gradio interface
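# Layout: a collapsible "Database Management" accordion with expand/refresh
# controls, then a two-column row with the prompt input on the left and the
# Markdown analysis results on the right.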
with gr.Blocks(title="ToGMAL Prompt Difficulty Analyzer") as demo:
gr.Markdown("# 🧠 ToGMAL Prompt Difficulty Analyzer")
gr.Markdown("Enter any prompt to see how difficult it is for current LLMs based on real benchmark data.")
# Database expansion section
with gr.Accordion("πŸ“Š Database Management", open=False):
db_info = gr.Markdown(get_database_info())
with gr.Row():
expand_btn = gr.Button("πŸš€ Expand Database (+5K questions)", variant="secondary")
refresh_btn = gr.Button("πŸ”„ Refresh Stats", variant="secondary")
expand_output = gr.Markdown()
with gr.Row():
with gr.Column():
prompt_input = gr.Textbox(
label="Enter your prompt",
placeholder="e.g., Calculate the quantum correction to the partition function...",
lines=3
)
k_slider = gr.Slider(
minimum=1,
maximum=10,
value=5,
step=1,
label="Number of similar questions to show"
)
submit_btn = gr.Button("Analyze Difficulty", variant="primary")
with gr.Column():
result_output = gr.Markdown(label="Analysis Results")
# Examples
gr.Examples(
examples=[
"Calculate the quantum correction to the partition function for a 3D harmonic oscillator",
"Prove that there are infinitely many prime numbers",
"Diagnose a patient with acute chest pain and shortness of breath",
"Explain the legal doctrine of precedent in common law systems",
"Implement a binary search tree with insert and search operations",
"What is 2 + 2?",
"What is the capital of France?"
],
inputs=prompt_input
)
# Event handling
submit_btn.click(
fn=analyze_prompt,
inputs=[prompt_input, k_slider],
outputs=result_output
)
prompt_input.submit(
fn=analyze_prompt,
inputs=[prompt_input, k_slider],
outputs=result_output
)
expand_btn.click(
fn=expand_database,
inputs=[],
outputs=expand_output
)
refresh_btn.click(
fn=get_database_info,
inputs=[],
outputs=db_info
)
if __name__ == "__main__":
    # HuggingFace Spaces: default to port 7860 unless the infrastructure
    # overrides it via GRADIO_SERVER_PORT (os is already imported above)
port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
demo.launch(server_name="0.0.0.0", server_port=port)