Spaces:
Sleeping
Sleeping
| from .base_benchmark import BaseBenchmark | |
| from typing import Dict, Any, Optional, Tuple | |
| from datasets import load_dataset | |
| import re | |
| import random | |
| from .evaluation_utils import extract_answer_mmlu | |
class GPQABenchmark(BaseBenchmark):
    """GPQA (Graduate-Level Google-Proof Q&A) multiple-choice benchmark.

    Each loaded sample is a dict with the keys ``question``, ``choices``
    (four answer strings in shuffled order), ``correct_index`` (position of
    the correct answer inside ``choices``), ``subject`` and ``raw_sample``.
    """

    def __init__(self):
        super().__init__(name="GPQA", dataset_name="Idavidrein/gpqa")

    @staticmethod
    def _shuffle_choices(sample) -> Tuple[list, int]:
        """Return ``(shuffled_choices, correct_index)`` for one raw GPQA row.

        The correct answer is placed first before shuffling, so its final
        position is wherever original index 0 ends up.
        """
        choices = [
            sample.get('Correct Answer', ''),
            sample.get('Incorrect Answer 1', ''),
            sample.get('Incorrect Answer 2', ''),
            sample.get('Incorrect Answer 3', ''),
        ]
        order = list(range(len(choices)))
        random.shuffle(order)
        shuffled_choices = [choices[i] for i in order]
        correct_index = order.index(0)  # 0 was the correct answer position
        return shuffled_choices, correct_index

    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load a GPQA subset into ``self.dataset``.

        Args:
            sample_size: If truthy and smaller than the number of loaded
                samples, keep only that many samples (taken after shuffling).
            **kwargs: ``subset`` selects the dataset configuration
                (``gpqa_main``, ``gpqa_diamond`` or ``gpqa_extended``);
                defaults to ``gpqa_main``.

        Raises:
            Exception: With setup instructions when the failure looks like an
                authentication / gated-dataset problem; otherwise the original
                loading error is re-raised if the ``gpqa_main`` fallback also
                fails.
        """
        import os

        # GPQA has different subsets: gpqa_main, gpqa_diamond, gpqa_extended
        subset = kwargs.get('subset', 'gpqa_main')

        # Pass the token only when one is configured, and reuse the same
        # arguments for the fallback below: the dataset is gated, so a
        # fallback without the token could never succeed.
        hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGING_FACE_HUB_TOKEN')
        auth_kwargs = {'token': hf_token} if hf_token else {}

        try:
            dataset = load_dataset(
                self.dataset_name, subset, split='train', **auth_kwargs
            )
        except Exception as e:
            message = str(e).lower()
            if "gated dataset" in message or "authentication" in message:
                raise Exception(
                    "GPQA dataset requires authentication. Please:\n"
                    "1. Set HF_TOKEN environment variable\n"
                    "2. Request access at https://huggingface.co/datasets/Idavidrein/gpqa\n"
                    f"Original error: {e}"
                ) from e
            # Fallback to main if subset not found
            try:
                dataset = load_dataset(
                    self.dataset_name, 'gpqa_main', split='train', **auth_kwargs
                )
            except Exception:
                # `except Exception` (not a bare except) so task cancellation
                # and KeyboardInterrupt propagate instead of being replaced
                # by the earlier loading error.
                raise e

        self.dataset = []
        for sample in dataset:
            # GPQA has these fields: Question, Correct Answer, Incorrect Answer 1-3
            shuffled_choices, correct_index = self._shuffle_choices(sample)
            self.dataset.append({
                'question': sample['Question'],
                'choices': shuffled_choices,
                'correct_index': correct_index,
                'subject': sample.get('Subdomain', 'Unknown'),
                'raw_sample': sample,
            })

        # Shuffle dataset so that truncation below takes a random subset
        random.shuffle(self.dataset)
        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format a GPQA sample as a four-option (A-D) multiple-choice prompt.

        Args:
            sample: Dict with a ``question`` string and a ``choices`` sequence
                of at least four answer strings.

        Returns:
            The prompt text, ending with ``Answer:``.
        """
        question = sample['question']
        choices = sample['choices']
        # GPQA uses a simpler format in lm-eval
        prompt_lines = [
            f"What is the correct answer to this question: {question}",
            "Choices:",
            f"(A) {choices[0]}",
            f"(B) {choices[1]}",
            f"(C) {choices[2]}",
            f"(D) {choices[3]}",
            "Answer:",
        ]
        return "\n".join(prompt_lines)

    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Query the model for one GPQA sample and score its answer.

        Args:
            api: Object exposing an awaitable ``generate_with_retry(prompt, **kwargs)``.
            sample: A sample dict as produced by ``load_dataset``.
            **kwargs: Forwarded unchanged to ``api.generate_with_retry``.

        Returns:
            ``(is_correct, result)``. On success ``result`` holds the question,
            choices, correct and predicted indices, raw model response and
            subject. If anything raises, ``result`` holds only the question,
            the error text and ``is_correct=False``.
        """
        prompt = self.format_prompt(sample)
        try:
            response = await api.generate_with_retry(prompt, **kwargs)

            # Extract answer from response using standard extraction.
            # NOTE(review): assumes extract_answer_mmlu returns a single
            # uppercase letter or a falsy value - confirm against the helper.
            predicted_letter = extract_answer_mmlu(response)
            if predicted_letter:
                predicted_index = ord(predicted_letter) - ord('A')
            else:
                # If no clear answer, mark as incorrect
                predicted_index = -1

            correct_index = sample['correct_index']
            is_correct = predicted_index == correct_index
            result = {
                'question': sample['question'],
                'choices': sample['choices'],
                'correct_answer': correct_index,
                'predicted_answer': predicted_index,
                'model_response': response,
                'is_correct': is_correct,
                'subject': sample['subject'],
            }
            return is_correct, result
        except Exception as e:
            # Any failure (API error, malformed sample, bad extraction) is
            # recorded as an incorrect answer rather than aborting the run.
            result = {
                'question': sample['question'],
                'error': str(e),
                'is_correct': False,
            }
            return False, result