rohansampath committed on
Commit
714de6d
·
verified ·
1 Parent(s): 532a4a4

Update mmlu_eval_original.py

Browse files
Files changed (1) hide show
  1. mmlu_eval_original.py +30 -4
mmlu_eval_original.py CHANGED
@@ -76,6 +76,9 @@ def gen_prompt(df, subject, k=-1):
76
 
77
  @torch.no_grad()
78
  def eval (subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=5, train_shots=5):
 
 
 
79
  cors = []
80
  all_probs = []
81
 
@@ -125,6 +128,12 @@ def eval (subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=
125
 
126
  cor = pred == label
127
 
 
 
 
 
 
 
128
  cors.append(cor)
129
  all_probs.append(probs)
130
 
@@ -151,6 +160,10 @@ def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=
151
  test_df = pd.DataFrame(dataset['test'])
152
  dev_df = pd.DataFrame(dataset['dev'])
153
 
 
 
 
 
154
  subjects = sorted(test_df['subject'].unique())
155
 
156
  results = {}
@@ -158,6 +171,7 @@ def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=
158
  incorrect_examples = []
159
  all_accuracies = []
160
  all_cors = []
 
161
 
162
  for subject in subjects:
163
  test_samples = test_df[test_df['subject'] == subject].head(num_questions)
@@ -167,15 +181,27 @@ def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=
167
  logger.info(f"Subject: {subject}, Test Samples: {len(test_samples)}, Dev Samples: {len(dev_samples)}")
168
 
169
  cors, acc, probs = eval(subject, model, tokenizer, dev_samples, test_samples, num_questions_per_subject=num_questions, train_shots=num_shots)
 
170
  all_cors.append(cors)
171
-
172
- weighted_acc = np.mean(np.concatenate(all_cors))
 
 
 
 
 
173
 
174
 
 
 
 
 
 
 
 
175
  return {
176
  "overall_accuracy": weighted_acc,
177
  "min_accuracy_subject": (min_acc_subject, results[min_acc_subject]),
178
  "max_accuracy_subject": (max_acc_subject, results[max_acc_subject]),
179
- "correct_examples": correct_examples,
180
- "incorrect_examples": incorrect_examples,
181
  }
 
76
 
77
  @torch.no_grad()
78
  def eval (subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=5, train_shots=5):
79
+ assert all(dev_df['subject'] == subject), f"Not all items in dev_df match subject {subject}"
80
+ assert all(test_df['subject'] == subject), f"Not all items in test_df match subject {subject}"
81
+
82
  cors = []
83
  all_probs = []
84
 
 
128
 
129
  cor = pred == label
130
 
131
+ logger.info(f"Label: {label}")
132
+ logger.info(f"Logits: {logits}")
133
+ logger.info(f"Probabilities: {probs}")
134
+ logger.info(f"Prediction: {pred}")
135
+ logger.info(f"Correct: {cor}")
136
+
137
  cors.append(cor)
138
  all_probs.append(probs)
139
 
 
160
  test_df = pd.DataFrame(dataset['test'])
161
  dev_df = pd.DataFrame(dataset['dev'])
162
 
163
+ # Sort datasets by subject and other relevant columns
164
+ test_df = test_df.sort_values(['subject', 'question'])
165
+ dev_df = dev_df.sort_values(['subject', 'question'])
166
+
167
  subjects = sorted(test_df['subject'].unique())
168
 
169
  results = {}
 
171
  incorrect_examples = []
172
  all_accuracies = []
173
  all_cors = []
174
+ results_table = []
175
 
176
  for subject in subjects:
177
  test_samples = test_df[test_df['subject'] == subject].head(num_questions)
 
181
  logger.info(f"Subject: {subject}, Test Samples: {len(test_samples)}, Dev Samples: {len(dev_samples)}")
182
 
183
  cors, acc, probs = eval(subject, model, tokenizer, dev_samples, test_samples, num_questions_per_subject=num_questions, train_shots=num_shots)
184
+ results[subject] = acc
185
  all_cors.append(cors)
186
+
187
+ results_table.append({
188
+ 'Subject': subject,
189
+ 'Num_samples': len(test_samples),
190
+ 'Num_correct': int(np.sum(cors)),
191
+ 'Accuracy': acc
192
+ })
193
 
194
 
195
+
196
+
197
+ weighted_acc = np.mean(np.concatenate(all_cors))
198
+
199
+ min_acc_subject = min(results.items(), key=lambda x: x[1])[0]
200
+ max_acc_subject = max(results.items(), key=lambda x: x[1])[0]
201
+
202
  return {
203
  "overall_accuracy": weighted_acc,
204
  "min_accuracy_subject": (min_acc_subject, results[min_acc_subject]),
205
  "max_accuracy_subject": (max_acc_subject, results[max_acc_subject]),
206
+ "full_accuracy_table": results_table
 
207
  }