HeTalksInMaths committed
Commit 78682b6 · Parent: d97cc93
Update to 26K question database with batch indexing
- Load MMLU-Pro validation + test splits (12K questions)
- Batch indexing for stability (1K per batch)
- Fixed stats sampling to use all questions
- Dynamic question count display
- ~10-15 min first launch, then instant
- app.py +81 -13
- benchmark_vector_db.py +3 -3
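Before the diffs: the "batch indexing for stability" bullet comes down to ceil-division chunking, the same loop added in app.py below. A minimal, self-contained sketch of the pattern (the names `items` and `index_batch` are hypothetical stand-ins, not the app's real identifiers):

    # Sketch of the batching pattern used in this commit (hypothetical names).
    from math import ceil

    def index_in_batches(items, index_batch, batch_size=1000):
        """Index `items` in fixed-size chunks; ceil division gives the batch count."""
        total_batches = ceil(len(items) / batch_size)  # e.g. 12,032 items -> 13 batches
        for i in range(0, len(items), batch_size):
            batch = items[i:i + batch_size]
            print(f"Indexing batch {i // batch_size + 1}/{total_batches} ({len(batch)} items)")
            index_batch(batch)

    # Usage: 12,032 MMLU-Pro test questions -> 12 full batches plus 1 partial batch of 32.
    index_in_batches(list(range(12032)), lambda batch: None)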
app.py CHANGED

@@ -25,19 +25,85 @@ db = BenchmarkVectorDB(
     embedding_model="all-MiniLM-L6-v2"
 )
 
-# Build database if not exists (first launch on Hugging Face)
-if db.collection.count() == 0:
-    db.build_database(
-        load_gpqa=False,  # Skip GPQA (requires auth)
-        load_mmlu_pro=True,
-        load_math=False,  # Skip MATH (dataset path issues)
-        max_samples_per_dataset=1000
-    )
+# Build expanded database if not exists (first launch on Hugging Face)
+current_count = db.collection.count()
+
+if current_count == 0:
+    logger.info("Database is empty - building expanded database from scratch...")
+    logger.info("This will take ~10-15 minutes on first launch (building 26K+ questions).")
+
+    # Load MMLU-Pro test split for comprehensive coverage
+    try:
+        from datasets import load_dataset
+        from benchmark_vector_db import BenchmarkQuestion
+
+        # Load MMLU-Pro validation + test splits
+        logger.info("Loading MMLU-Pro validation split...")
+        val_dataset = load_dataset("TIGER-Lab/MMLU-Pro", split="validation")
+        logger.info(f"  Loaded {len(val_dataset)} validation questions")
+
+        logger.info("Loading MMLU-Pro test split...")
+        test_dataset = load_dataset("TIGER-Lab/MMLU-Pro", split="test")
+        logger.info(f"  Loaded {len(test_dataset)} test questions")
+
+        all_questions = []
+
+        # Process validation split
+        for idx, item in enumerate(val_dataset):
+            question = BenchmarkQuestion(
+                question_id=f"mmlu_pro_val_{idx}",
+                source_benchmark="MMLU_Pro",
+                domain=item.get('category', 'unknown').lower(),
+                question_text=item['question'],
+                correct_answer=item['answer'],
+                choices=item.get('options', []),
+                success_rate=0.45,
+                difficulty_score=0.55,
+                difficulty_label="Hard",
+                num_models_tested=0
+            )
+            all_questions.append(question)
+
+        # Process test split
+        for idx, item in enumerate(test_dataset):
+            question = BenchmarkQuestion(
+                question_id=f"mmlu_pro_test_{idx}",
+                source_benchmark="MMLU_Pro",
+                domain=item.get('category', 'unknown').lower(),
+                question_text=item['question'],
+                correct_answer=item['answer'],
+                choices=item.get('options', []),
+                success_rate=0.45,
+                difficulty_score=0.55,
+                difficulty_label="Hard",
+                num_models_tested=0
+            )
+            all_questions.append(question)
+
+        logger.info(f"Total questions to index: {len(all_questions)}")
+
+        # Index in batches of 1000 for stability
+        batch_size = 1000
+        for i in range(0, len(all_questions), batch_size):
+            batch = all_questions[i:i + batch_size]
+            batch_num = i // batch_size + 1
+            total_batches = (len(all_questions) + batch_size - 1) // batch_size
+            logger.info(f"Indexing batch {batch_num}/{total_batches} ({len(batch)} questions)...")
+            db.index_questions(batch)
+
+        logger.info(f"✓ Database build complete! Indexed {len(all_questions)} questions")
+
+    except Exception as e:
+        logger.error(f"Failed to build expanded database: {e}")
+        logger.info("Falling back to standard build...")
+        db.build_database(
+            load_gpqa=False,  # Skip GPQA (requires auth)
+            load_mmlu_pro=True,
+            load_math=False,  # Skip MATH (dataset path issues)
+            max_samples_per_dataset=1000
+        )
 else:
-    logger.info(f"✓ Loaded existing database with {db.collection.count()} questions")
+    logger.info(f"✓ Loaded existing database with {current_count:,} questions")
 
 def analyze_prompt(prompt: str, k: int = 5) -> str:
     """

@@ -75,7 +141,9 @@ def analyze_prompt(prompt: str, k: int = 5) -> str:
         output.append(f"  - Similarity: {q['similarity']:.3f}")
     output.append("")
 
-
+    # Get current database size
+    total_questions = db.collection.count()
+    output.append(f"*Analyzed using {k} most similar questions from {total_questions:,} benchmark questions*")
 
     return "\n".join(output)
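The body of `index_questions` is not part of this diff. Assuming it wraps ChromaDB's `collection.add` (the repo uses a Chroma collection, per `db.collection.count()` above), each 1K batch would boil down to something like this sketch; the collection name and metadata keys here are guesses, not the repo's actual schema:

    # Hypothetical per-batch add; the real index_questions may differ.
    from collections import namedtuple
    import chromadb

    client = chromadb.Client()  # in-memory; the Space presumably persists to disk
    collection = client.get_or_create_collection("benchmark_questions")  # name is a guess

    def add_batch(questions):
        # One collection.add call per batch keeps memory bounded and means a
        # failure partway through loses at most one batch of embeddings.
        collection.add(
            ids=[q.question_id for q in questions],
            documents=[q.question_text for q in questions],
            metadatas=[{"domain": q.domain,                     # assumed metadata keys
                        "source_benchmark": q.source_benchmark,
                        "difficulty_label": q.difficulty_label}
                       for q in questions],
        )

    # Usage with a stand-in question object:
    Q = namedtuple("Q", "question_id question_text domain source_benchmark difficulty_label")
    add_batch([Q("mmlu_pro_test_0", "What is 2+2?", "math", "MMLU_Pro", "Hard")])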
benchmark_vector_db.py CHANGED

@@ -568,9 +568,9 @@ class BenchmarkVectorDB:
         if count == 0:
             return {"total_questions": 0, "message": "No questions indexed yet"}
 
-        # Get sample for stats
-
-        sample = self.collection.get(limit=1000, include=["metadatas"])
+        # Get ALL questions for accurate stats (not just sample of 1000)
+        logger.info(f"Computing statistics from all {count} questions...")
+        sample = self.collection.get(limit=count, include=["metadatas"])
 
         domains = defaultdict(int)
         sources = defaultdict(int)
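With `limit=count`, the stats pass now walks every metadata record instead of the first 1,000, so the reported per-domain and per-source counts match the full 26K database. The aggregation that follows (`domains`, `sources`) is a plain counting pass over `sample["metadatas"]`, roughly like this self-contained sketch; the metadata keys mirror the BenchmarkQuestion fields above but are an assumption about the stored schema:

    # Standalone sketch of the counting pass over Chroma get() results.
    from collections import defaultdict

    sample = {"metadatas": [
        {"domain": "math", "source_benchmark": "MMLU_Pro"},
        {"domain": "law",  "source_benchmark": "MMLU_Pro"},
        {"domain": "math", "source_benchmark": "MMLU_Pro"},
    ]}

    domains = defaultdict(int)
    sources = defaultdict(int)
    for meta in sample["metadatas"]:
        domains[meta.get("domain", "unknown")] += 1
        sources[meta.get("source_benchmark", "unknown")] += 1

    print(dict(domains))  # {'math': 2, 'law': 1}
    print(dict(sources))  # {'MMLU_Pro': 3}

Pulling all 26K metadata dicts into memory is cheap at this scale (a few MB); at much larger counts the same `get` call could be paged with its `offset` parameter instead.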