Update mmlu_eval_original.py

mmlu_eval_original.py  CHANGED  (+30 -4)
@@ -76,6 +76,9 @@ def gen_prompt(df, subject, k=-1):
 
 @torch.no_grad()
 def eval (subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=5, train_shots=5):
+    assert all(dev_df['subject'] == subject), f"Not all items in dev_df match subject {subject}"
+    assert all(test_df['subject'] == subject), f"Not all items in test_df match subject {subject}"
+
     cors = []
     all_probs = []
 
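The two new asserts reduce a boolean pandas Series with Python's built-in all(), so a single mismatched row aborts the evaluation early with a descriptive message. A minimal sketch of the behavior, on toy data that is not part of the commit:

import pandas as pd

dev_df = pd.DataFrame({'subject': ['anatomy', 'anatomy'], 'question': ['q1', 'q2']})

# Passes: every row belongs to the expected subject.
assert all(dev_df['subject'] == 'anatomy')

mixed_df = pd.DataFrame({'subject': ['anatomy', 'physics'], 'question': ['q1', 'q2']})
# Would raise AssertionError: one row has a different subject.
# assert all(mixed_df['subject'] == 'anatomy')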
@@ -125,6 +128,12 @@ def eval (subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=
 
         cor = pred == label
 
+        logger.info(f"Label: {label}")
+        logger.info(f"Logits: {logits}")
+        logger.info(f"Probabilities: {probs}")
+        logger.info(f"Prediction: {pred}")
+        logger.info(f"Correct: {cor}")
+
         cors.append(cor)
         all_probs.append(probs)
 
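These per-question log lines assume a module-level logger configured elsewhere in the script; the configuration is not shown in this diff, but a typical setup (an assumption, not the file's verified code) looks like:

import logging

logging.basicConfig(level=logging.INFO)  # assumption: INFO-level output to stderr
logger = logging.getLogger(__name__)

Logging the label, logits, probabilities, and prediction for every question is verbose; it helps when debugging a handful of subjects, but for full runs it may be worth moving these calls to DEBUG level.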
@@ -151,6 +160,10 @@ def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=
     test_df = pd.DataFrame(dataset['test'])
     dev_df = pd.DataFrame(dataset['dev'])
 
+    # Sort datasets by subject and other relevant columns
+    test_df = test_df.sort_values(['subject', 'question'])
+    dev_df = dev_df.sort_values(['subject', 'question'])
+
     subjects = sorted(test_df['subject'].unique())
 
     results = {}
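Sorting both frames by subject and question before slicing makes the later head(num_questions) subset deterministic across runs, since the upstream dataset order is not guaranteed. A toy illustration with hypothetical data:

import pandas as pd

df = pd.DataFrame({'subject': ['b', 'a', 'a'], 'question': ['q2', 'q9', 'q1']})
df = df.sort_values(['subject', 'question'])

# Always selects the same row ('a', 'q1') regardless of the original order.
print(df[df['subject'] == 'a'].head(1))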
@@ -158,6 +171,7 @@ def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=
     incorrect_examples = []
     all_accuracies = []
     all_cors = []
+    results_table = []
 
     for subject in subjects:
         test_samples = test_df[test_df['subject'] == subject].head(num_questions)
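results_table accumulates one plain dict per subject in the loop that follows, so it converts directly into a DataFrame for display or export. A usage sketch with hypothetical rows in the shape the loop produces (not part of the commit):

import pandas as pd

results_table = [
    {'Subject': 'anatomy', 'Num_samples': 5, 'Num_correct': 4, 'Accuracy': 0.8},
    {'Subject': 'physics', 'Num_samples': 5, 'Num_correct': 3, 'Accuracy': 0.6},
]
print(pd.DataFrame(results_table).to_string(index=False))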
@@ -167,15 +181,27 @@ def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=
         logger.info(f"Subject: {subject}, Test Samples: {len(test_samples)}, Dev Samples: {len(dev_samples)}")
 
         cors, acc, probs = eval(subject, model, tokenizer, dev_samples, test_samples, num_questions_per_subject=num_questions, train_shots=num_shots)
+        results[subject] = acc
         all_cors.append(cors)
-
-
+
+        results_table.append({
+            'Subject': subject,
+            'Num_samples': len(test_samples),
+            'Num_correct': int(np.sum(cors)),
+            'Accuracy': acc
+        })
 
 
+
+
+    weighted_acc = np.mean(np.concatenate(all_cors))
+
+    min_acc_subject = min(results.items(), key=lambda x: x[1])[0]
+    max_acc_subject = max(results.items(), key=lambda x: x[1])[0]
+
     return {
         "overall_accuracy": weighted_acc,
         "min_accuracy_subject": (min_acc_subject, results[min_acc_subject]),
         "max_accuracy_subject": (max_acc_subject, results[max_acc_subject]),
-        "
-        "incorrect_examples": incorrect_examples,
+        "full_accuracy_table": results_table
     }
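Worth noting: np.mean(np.concatenate(all_cors)) is a question-weighted (micro) average rather than the mean of per-subject accuracies, so subjects contributing more sampled questions count more; when every subject gets the same number of questions the two coincide. A self-contained sketch of the aggregation, mirroring the logic above with toy values:

import numpy as np

# Hypothetical per-subject correctness arrays and accuracies.
all_cors = [np.array([True, True, False]), np.array([True, False, False])]
results = {'anatomy': 2 / 3, 'physics': 1 / 3}

weighted_acc = np.mean(np.concatenate(all_cors))               # 0.5, pooled over all questions
min_acc_subject = min(results.items(), key=lambda x: x[1])[0]  # 'physics'
max_acc_subject = max(results.items(), key=lambda x: x[1])[0]  # 'anatomy'
print(weighted_acc, min_acc_subject, max_acc_subject)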