# H2H-eval-comparator / mmlu_eval_original.py
import torch
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces
import logging
import numpy as np
import pandas as pd

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

accuracy_metric = evaluate.load("accuracy")
choices = ["A", "B", "C", "D"]

# Hard-coded for now; should eventually be taken from the model's own config.
MAX_CONTEXT_WINDOW = 4096
def load_dataset_from_hf(verbose=False):
mmlu_dataset = load_dataset("cais/mmlu", "all")
if verbose:
for split in mmlu_dataset.keys():
dataset = mmlu_dataset[split] # Access the dataset split
# Log number of rows and columns
num_rows = len(dataset)
num_cols = len(dataset.column_names)
logger.info(f"Dataset Split: {split}")
logger.info(f"Number of Rows: {num_rows}")
logger.info(f"Number of Columns: {num_cols}")
# Log column names and their types
column_types = {col: str(dataset.features[col].dtype) for col in dataset.column_names}
logger.info(f"Column Names: {dataset.column_names}")
logger.info(f"Column Types: {column_types}")
# Log a sample of 5 rows
sample_rows = dataset.select(range(min(5, num_rows))) # Ensure we don't exceed available rows
logger.info("Sample Rows:")
for row in sample_rows:
logger.info(row)
logger.info("=" * 50) # Separator for readability
return mmlu_dataset
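
# For reference, each cais/mmlu row looks roughly like:
#   {"question": "...", "subject": "abstract_algebra",
#    "choices": ["...", "...", "...", "..."], "answer": 2}
# where "answer" is the 0-based index of the correct choice. The prompt
# formatting below relies on this layout.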

def format_subject(subject):
    # Convert "abstract_algebra" -> " abstract algebra"
    # (the leading space from the original implementation is preserved).
    return "".join(" " + entry for entry in subject.split("_"))

def format_example(df, idx, include_answer=True):
    # Build one question block. Rows come from cais/mmlu, so the options live
    # in a single "choices" list and "answer" is a 0-based index.
    row = df.iloc[idx]
    prompt = row["question"]
    for j, choice in enumerate(row["choices"]):
        prompt += "\n{}. {}".format(choices[j], choice)
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(choices[row["answer"]])
    return prompt

def gen_prompt(df, subject, k=-1):
prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
format_subject(subject)
)
if k == -1:
k = df.shape[0]
for i in range(k):
prompt += format_example(df, i)
return prompt
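
# A few-shot prompt produced by gen_prompt + format_example looks roughly like:
#
#   The following are multiple choice questions (with answers) about abstract algebra.
#
#   <dev question 1>
#   A. <choice>
#   B. <choice>
#   C. <choice>
#   D. <choice>
#   Answer: B
#
#   ... (train_shots examples total) ...
#
#   <test question>
#   A. <choice>
#   ...
#   Answer: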

@torch.no_grad()
def eval(subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=5, train_shots=5):
    cors = []
    all_probs = []

    if train_shots < 0:
        train_shots = 0  # Clamp negative shot counts to zero.
    for i in range(test_df.shape[0]):
        prompt_end = format_example(test_df, i, include_answer=False)
        train_prompt = gen_prompt(dev_df, subject, train_shots)
        prompt = train_prompt + prompt_end
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

        # Drop few-shot examples until the prompt fits in the context window.
        while train_shots > 0 and input_ids.shape[-1] > MAX_CONTEXT_WINDOW:
            train_shots -= 1
            train_prompt = gen_prompt(dev_df, subject, train_shots)
            prompt = train_prompt + prompt_end
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

        logger.info(f"Prompt: {prompt}")
        # The gold answer is a 0-based index into the choices list; map it to a letter.
        label = choices[test_df.iloc[i]["answer"]]
logits = model(input_ids=input_ids).logits[0, -1]
        # Softmax over only the four answer-letter logits. This assumes "A"-"D"
        # each map to a single trailing token for this tokenizer.
        option_logits = torch.stack(
            [logits[tokenizer(letter).input_ids[-1]] for letter in choices]
        ).float()
        probs = torch.nn.functional.softmax(option_logits, dim=0).detach().cpu().numpy()
        pred = choices[int(np.argmax(probs))]
cor = pred == label
cors.append(cor)
all_probs.append(probs)
acc = np.mean(cors)
cors = np.array(cors)
all_probs = np.array(all_probs)
print("Average accuracy {:.3f} - {}".format(acc, subject))
return cors, acc, all_probs

def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=5):
    """
    Evaluate the model on MMLU: num_questions test questions per subject with
    num_shots few-shot examples; num_subjects=-1 means all subjects.
    """
    model.eval()  # Ensure Dropout and BatchNorm behave appropriately for inference.
dataset = load_dataset_from_hf(verbose=True)
# Convert dataset partitions to pandas DataFrames
test_df = pd.DataFrame(dataset['test'])
dev_df = pd.DataFrame(dataset['dev'])
subjects = sorted(test_df['subject'].unique())
    results = {}
    # Returned as-is below; per-example bookkeeping is not populated in this version.
    correct_examples = []
    incorrect_examples = []
    all_accuracies = []
    all_cors = []

    # Optionally restrict the run to the first num_subjects subjects.
    if num_subjects > 0:
        subjects = subjects[:num_subjects]

    for subject in subjects:
        test_samples = test_df[test_df['subject'] == subject].head(num_questions)
        dev_samples = dev_df[dev_df['subject'] == subject].head(num_shots)

        # Log subject and sample counts
        logger.info(f"Subject: {subject}, Test Samples: {len(test_samples)}, Dev Samples: {len(dev_samples)}")

        cors, acc, probs = eval(subject, model, tokenizer, dev_samples, test_samples, num_questions_per_subject=num_questions, train_shots=num_shots)
        results[subject] = acc
        all_accuracies.append(acc)
        all_cors.append(cors)

    # Overall accuracy weighted by the number of questions answered per subject.
    weighted_acc = np.mean(np.concatenate(all_cors))
    min_acc_subject = min(results, key=results.get)
    max_acc_subject = max(results, key=results.get)

    return {
        "overall_accuracy": weighted_acc,
        "min_accuracy_subject": (min_acc_subject, results[min_acc_subject]),
        "max_accuracy_subject": (max_acc_subject, results[max_acc_subject]),
        "correct_examples": correct_examples,
        "incorrect_examples": incorrect_examples,
    }
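

if __name__ == "__main__":
    # Minimal local smoke test, not part of the Space itself. The model name
    # ("gpt2") and the tiny subject/question/shot counts are illustrative
    # placeholders; note gpt2's 1,024-token context is smaller than
    # MAX_CONTEXT_WINDOW, so keep the shot count low when using it.
    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")
    summary = evaluate_mmlu(lm, tok, num_subjects=2, num_questions=3, num_shots=2)
    print("Overall accuracy:", summary["overall_accuracy"])
    print("Worst subject:", summary["min_accuracy_subject"])
    print("Best subject:", summary["max_accuracy_subject"])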