rohansampath committed
Commit bd05b7b · verified · 1 Parent(s): b7b8b1a

Update app.py

Files changed (1)
  1. app.py +146 -110
app.py CHANGED
@@ -1,109 +1,111 @@
  import gradio as gr
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
  import os
  from huggingface_hub import login
- from toy_dataset_eval import evaluate_toy_dataset
  from mmlu_pro_eval_adapted import evaluate_mmlu_pro
  import spaces
  import pandas as pd
- import time # Added for timing functionality

  # Read token and login
  hf_token = os.getenv("HF_READ_WRITE_TOKEN")
  if hf_token:
      login(hf_token)
  else:
-     print("⚠️ No HF_TOKEN_READ_WRITE found in environment")

  # ---------------------------------------------------------------------------
- # 1. Model and tokenizer setup and Loading
  # ---------------------------------------------------------------------------
  model_name = "mistralai/Mistral-7B-v0.1"
- tokenizer = None
- model = None
- model_loaded = False
-
  # ---------------------------------------------------------------------------
- # 1. MMLU-Pro Evaluation call
  # ---------------------------------------------------------------------------
- @spaces.GPU(duration=120) # Allow up to 2 minutes for full evaluation
  def run_mmlu_evaluation(all_subjects, num_subjects, num_shots, all_questions, num_questions, progress=gr.Progress()):
      """
      Runs the MMLU evaluation with the specified parameters.

      Args:
          all_subjects (bool): Whether to evaluate all subjects
-         num_subjects (int): Number of subjects to evaluate (1-57)
          num_shots (int): Number of few-shot examples (0-5)
          all_questions (bool): Whether to evaluate all questions per subject
-         num_questions (int): Number of examples per subject (1-20 or -1 for all)
          progress (gr.Progress): Progress indicator
      """
-
-     # Convert num_subjects to -1 if all_subjects is True
-     if all_subjects:
-         num_subjects = -1
-
-     # Convert num_questions to -1 if all_questions is True
-     if all_questions:
-         num_questions = -1

-     # Run evaluation with timing
-     start_time = time.time() # Start timing
-     results = evaluate_mmlu_pro(
-         model_name,
-         num_subjects=num_subjects,
-         num_questions=num_questions,
-         num_shots=num_shots,
-     )
-     elapsed_time = time.time() - start_time # Calculate elapsed time

-     # Format results
-     overall_acc = results["overall_accuracy"]
-     min_subject, min_acc = results["min_accuracy_subject"]
-     max_subject, max_acc = results["max_accuracy_subject"]
-
-     # Create DataFrame from results table
-     results_df = pd.DataFrame(results["full_accuracy_table"])
-
-     # Calculate totals for the overall row
-     total_samples = results_df['Num_samples'].sum()
-     total_correct = results_df['Num_correct'].sum()
-
-     # Create overall row
-     overall_row = pd.DataFrame({
-         'Subject': ['**Overall**'],
-         'Num_samples': [total_samples],
-         'Num_correct': [total_correct],
-         'Accuracy': [overall_acc]
-     })
-
-     # Concatenate overall row with results
-     results_df = pd.concat([overall_row, results_df], ignore_index=True)
-
-     # Verify that the overall accuracy is consistent with the total correct/total samples
-     assert abs(overall_acc - (total_correct / total_samples)) < 1e-6, \
-         "Overall accuracy calculation mismatch detected"
-
-     # Format the report
-     report = (
-         f"### Overall Results\n"
-         f"* Overall Accuracy: {overall_acc:.3f}\n"
-         f"* Best Performance: {max_subject} ({max_acc:.3f})\n"
-         f"* Worst Performance: {min_subject} ({min_acc:.3f})\n"
-         f"* Evaluation completed in {elapsed_time:.2f} seconds\n"
-     )

-     # Return values that re-enable UI components after completion
-     return (report, results_df,
-             gr.update(interactive=True), gr.update(visible=False),
-             gr.update(interactive=True), gr.update(interactive=True),
-             gr.update(interactive=True), gr.update(interactive=True),
-             gr.update(interactive=True))

  # ---------------------------------------------------------------------------
- # 4. Gradio Interface
  # ---------------------------------------------------------------------------
  with gr.Blocks() as demo:
      gr.Markdown("# Mistral-7B on MMLU-Pro Evaluation Demo")
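Aside on the assert in the deleted block above: overall accuracy must equal total_correct / total_samples, the sample-weighted average, which generally differs from the unweighted mean of per-subject accuracies. A minimal sketch with made-up numbers (not real evaluation data):

import pandas as pd

df = pd.DataFrame({
    "Subject": ["math", "law"],
    "Num_samples": [20, 5],
    "Num_correct": [10, 5],
})
df["Accuracy"] = df["Num_correct"] / df["Num_samples"]

overall = df["Num_correct"].sum() / df["Num_samples"].sum()  # 15 / 25 = 0.6
unweighted = df["Accuracy"].mean()                           # (0.5 + 1.0) / 2 = 0.75

# The invariant the removed assert checked: the weighted total, not the mean of means.
assert abs(overall - 15 / 25) < 1e-6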
@@ -117,13 +119,13 @@ with gr.Blocks() as demo:
      with gr.Row():
          all_subjects_checkbox = gr.Checkbox(
              label="Evaluate All Subjects",
-             value=False, # Default is unchecked
              info="When checked, evaluates all 14 MMLU-Pro subjects"
          )
          num_subjects_slider = gr.Slider(
              minimum=1,
              maximum=14,
-             value=14, # Default is all subjects
              step=1,
              label="Number of Subjects",
              info="Number of subjects to evaluate (1-14). They will be loaded in alphabetical order.",
@@ -134,16 +136,16 @@ with gr.Blocks() as demo:
          num_shots_slider = gr.Slider(
              minimum=0,
              maximum=5,
-             value=5, # Default is 5 few-shot examples
              step=1,
              label="Number of Few-shot Examples",
-             info="Number of examples to use for few-shot learning (0-5). They will be loaded in alphabetical order."
          )

      with gr.Row():
          all_questions_checkbox = gr.Checkbox(
              label="Evaluate All Questions",
-             value=False, # Default is unchecked
              info="When checked, evaluates all available questions for each subject"
          )
          questions_info_text = gr.Markdown(visible=False, value="**All 12,032 questions across all subjects will be evaluated**")
@@ -151,33 +153,32 @@ with gr.Blocks() as demo:
      with gr.Row(elem_id="questions_selection_row"):
          questions_container = gr.Column(scale=1, elem_id="questions_slider_container")

-         # Move the slider into the container for easier visibility toggling
          with questions_container:
              num_questions_slider = gr.Slider(
                  minimum=1,
                  maximum=40,
-                 value=20, # Default is 10 questions
                  step=1,
                  label="Questions per Subject",
-                 info="Choose a subset of questions (1-40) per subject. They will be loaded in order of question_id for reproducibility. ",
                  interactive=True
              )

      with gr.Row():
          with gr.Column(scale=1):
              eval_mmlu_button = gr.Button("Run MMLU-Pro Evaluation", variant="primary", interactive=True)
-             cancel_mmlu_button = gr.Button("Cancel MMLU-Pro Evaluation", variant="stop", visible=False)
          results_output = gr.Markdown(label="Evaluation Results")

      with gr.Row():
          results_table = gr.DataFrame(interactive=True, label="Detailed Results (Sortable)", visible=True)
-

      # Update num_subjects_slider interactivity based on all_subjects checkbox
      def update_subjects_slider(checked):
-         if checked:
-             return gr.update(value=14, interactive=False)
-         else:
-             return gr.update(interactive=True)

      all_subjects_checkbox.change(
          fn=update_subjects_slider,
@@ -199,45 +200,75 @@
      )

      # Function to disable UI components during evaluation
-     def disable_ui_for_evaluation():
          return [
-             gr.update(interactive=False, info="MMLU Evaluation currently in progress"), # all_subjects_checkbox
-             gr.update(interactive=False, info="MMLU Evaluation currently in progress"), # num_subjects_slider
-             gr.update(interactive=False, info="MMLU Evaluation currently in progress"), # num_shots_slider
-             gr.update(interactive=False, info="MMLU Evaluation currently in progress"), # all_questions_checkbox
-             gr.update(interactive=False, info="MMLU Evaluation currently in progress"), # num_questions_slider
              gr.update(interactive=False), # eval_mmlu_button
-             gr.update(visible=True) # cancel_mmlu_button
          ]

      # Function to handle cancel button click
-     def cancel_evaluation():
-         # This doesn't actually cancel the GPU job (which would require more backend support)
-         # But it does reset the UI state to be interactive again
          return [
-             gr.update(interactive=True, info="When checked, evaluates all 14 MMLU-Pro subjects"), # all_subjects_checkbox
-             gr.update(interactive=True, info="Number of subjects to evaluate (1-14). They will be loaded in alphabetical order."), # num_subjects_slider
-             gr.update(interactive=True, info="Number of examples to use for few-shot learning (0-5). They will be loaded in alphabetical order."), # num_shots_slider
-             gr.update(interactive=True, info="When checked, evaluates all available questions for each subject"), # all_questions_checkbox
-             gr.update(interactive=True, info="Choose a subset of questions (1-40) per subject. They will be loaded in order of question_id for reproducibility."), # num_questions_slider
              gr.update(interactive=True), # eval_mmlu_button
-             gr.update(visible=False), # cancel_mmlu_button
-             "⚠️ Evaluation canceled by user", # results_output
-             None # results_table
          ]

-     # Connect MMLU evaluation button - now disables UI and shows cancel button
      eval_mmlu_button.click(
-         fn=disable_ui_for_evaluation,
-         inputs=None,
          outputs=[
              all_subjects_checkbox,
              num_subjects_slider,
              num_shots_slider,
              all_questions_checkbox,
              num_questions_slider,
              eval_mmlu_button,
-             cancel_mmlu_button
          ]
      ).then(
          fn=run_mmlu_evaluation,
@@ -259,13 +290,18 @@ with gr.Blocks() as demo:
              all_questions_checkbox,
              num_questions_slider
          ]
      )

      # Connect cancel button
      cancel_mmlu_button.click(
          fn=cancel_evaluation,
-         inputs=None,
          outputs=[
              all_subjects_checkbox,
              num_subjects_slider,
              num_shots_slider,
 
  import gradio as gr
  import os
  from huggingface_hub import login
  from mmlu_pro_eval_adapted import evaluate_mmlu_pro
  import spaces
  import pandas as pd
+ import time
+ import traceback

  # Read token and login
  hf_token = os.getenv("HF_READ_WRITE_TOKEN")
  if hf_token:
      login(hf_token)
  else:
+     print("⚠️ No HF_READ_WRITE_TOKEN found in environment")

  # ---------------------------------------------------------------------------
+ # 1. Model configuration
  # ---------------------------------------------------------------------------
  model_name = "mistralai/Mistral-7B-v0.1"
+

  # ---------------------------------------------------------------------------
+ # 2. MMLU-Pro Evaluation
  # ---------------------------------------------------------------------------
+ @spaces.GPU(duration=180) # Extended to 3 minutes for larger evaluations
  def run_mmlu_evaluation(all_subjects, num_subjects, num_shots, all_questions, num_questions, progress=gr.Progress()):
      """
      Runs the MMLU evaluation with the specified parameters.

      Args:
          all_subjects (bool): Whether to evaluate all subjects
+         num_subjects (int): Number of subjects to evaluate (1-14)
          num_shots (int): Number of few-shot examples (0-5)
          all_questions (bool): Whether to evaluate all questions per subject
+         num_questions (int): Number of examples per subject (1-40 or all)
          progress (gr.Progress): Progress indicator
      """
+     try:
+         # Convert parameters if needed
+         if all_subjects:
+             num_subjects = -1
+
+         if all_questions:
+             num_questions = -1

+         # Run evaluation with timing
+         start_time = time.time()
+         results = evaluate_mmlu_pro(
+             model_name,
+             num_subjects=num_subjects,
+             num_questions=num_questions,
+             num_shots=num_shots,
+         )
+         elapsed_time = time.time() - start_time

+         # Format results
+         overall_acc = results["overall_accuracy"]
+         min_subject, min_acc = results["min_accuracy_subject"]
+         max_subject, max_acc = results["max_accuracy_subject"]
+
+         # Create DataFrame from results table
+         results_df = pd.DataFrame(results["full_accuracy_table"])
+
+         # Calculate totals for the overall row
+         total_samples = results_df['Num_samples'].sum()
+         total_correct = results_df['Num_correct'].sum()
+
+         # Create overall row
+         overall_row = pd.DataFrame({
+             'Subject': ['**Overall**'],
+             'Num_samples': [total_samples],
+             'Num_correct': [total_correct],
+             'Accuracy': [overall_acc]
+         })
+
+         # Concatenate overall row with results
+         results_df = pd.concat([overall_row, results_df], ignore_index=True)
+
+         # Format the report
+         report = (
+             f"### Overall Results\n"
+             f"* Overall Accuracy: {overall_acc:.3f}\n"
+             f"* Best Performance: {max_subject} ({max_acc:.3f})\n"
+             f"* Worst Performance: {min_subject} ({min_acc:.3f})\n"
+             f"* Evaluation completed in {elapsed_time:.2f} seconds\n"
+         )

+         # Return values that re-enable UI components after completion
+         return (report, results_df,
+                 gr.update(interactive=True), gr.update(visible=False),
+                 gr.update(interactive=True), gr.update(interactive=True),
+                 gr.update(interactive=True), gr.update(interactive=True),
+                 gr.update(interactive=True))
+
+     except Exception as e:
+         # Handle errors gracefully
+         error_trace = traceback.format_exc()
+         error_message = f"### Error during evaluation\n```\n{error_trace}\n```"
+
+         # Re-enable UI components on error
+         return (error_message, None,
+                 gr.update(interactive=True), gr.update(visible=False),
+                 gr.update(interactive=True), gr.update(interactive=True),
+                 gr.update(interactive=True), gr.update(interactive=True),
+                 gr.update(interactive=True))
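The tuples returned above map positionally onto the outputs list registered on the event listener further down: one value per output component, in order. A minimal sketch of that Gradio convention (component names here are illustrative, not from app.py):

import gradio as gr

with gr.Blocks() as sketch:
    status = gr.Markdown()
    go_btn = gr.Button("Go")

    def produce():
        # Two output components registered below, so return two values, in order.
        return "done", gr.update(interactive=False)

    go_btn.click(fn=produce, inputs=None, outputs=[status, go_btn])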
 
  # ---------------------------------------------------------------------------
+ # 3. Gradio Interface
  # ---------------------------------------------------------------------------
  with gr.Blocks() as demo:
      gr.Markdown("# Mistral-7B on MMLU-Pro Evaluation Demo")
 
      with gr.Row():
          all_subjects_checkbox = gr.Checkbox(
              label="Evaluate All Subjects",
+             value=False,
              info="When checked, evaluates all 14 MMLU-Pro subjects"
          )
          num_subjects_slider = gr.Slider(
              minimum=1,
              maximum=14,
+             value=14,
              step=1,
              label="Number of Subjects",
              info="Number of subjects to evaluate (1-14). They will be loaded in alphabetical order.",

          num_shots_slider = gr.Slider(
              minimum=0,
              maximum=5,
+             value=5,
              step=1,
              label="Number of Few-shot Examples",
+             info="Number of examples to use for few-shot learning (0-5)."
          )

      with gr.Row():
          all_questions_checkbox = gr.Checkbox(
              label="Evaluate All Questions",
+             value=False,
              info="When checked, evaluates all available questions for each subject"
          )
          questions_info_text = gr.Markdown(visible=False, value="**All 12,032 questions across all subjects will be evaluated**")

      with gr.Row(elem_id="questions_selection_row"):
          questions_container = gr.Column(scale=1, elem_id="questions_slider_container")

          with questions_container:
              num_questions_slider = gr.Slider(
                  minimum=1,
                  maximum=40,
+                 value=20,
                  step=1,
                  label="Questions per Subject",
+                 info="Choose a subset of questions (1-40) per subject. They will be loaded in order of question_id.",
                  interactive=True
              )

      with gr.Row():
          with gr.Column(scale=1):
              eval_mmlu_button = gr.Button("Run MMLU-Pro Evaluation", variant="primary", interactive=True)
+             cancel_mmlu_button = gr.Button("Cancel Evaluation", variant="stop", visible=False)
          results_output = gr.Markdown(label="Evaluation Results")

      with gr.Row():
          results_table = gr.DataFrame(interactive=True, label="Detailed Results (Sortable)", visible=True)
+
+     # Track evaluation state - used to prevent multiple evaluations
+     evaluation_state = gr.State({"running": False})
+
      # Update num_subjects_slider interactivity based on all_subjects checkbox
      def update_subjects_slider(checked):
+         return gr.update(interactive=not checked)

      all_subjects_checkbox.change(
          fn=update_subjects_slider,

      )

      # Function to disable UI components during evaluation
+     def start_evaluation(state):
+         if state["running"]:
+             return [
+                 state,
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(visible=False),
+                 "Evaluation already in progress. Please wait.",
+                 None
+             ]
+
+         # Update state to running
+         state["running"] = True
+
          return [
+             state,
+             gr.update(interactive=False), # all_subjects_checkbox
+             gr.update(interactive=False), # num_subjects_slider
+             gr.update(interactive=False), # num_shots_slider
+             gr.update(interactive=False), # all_questions_checkbox
+             gr.update(interactive=False), # num_questions_slider
              gr.update(interactive=False), # eval_mmlu_button
+             gr.update(visible=True), # cancel_mmlu_button
+             "Starting evaluation...", # results_output
+             None # results_table
          ]

+     # Function to reset UI after evaluation
+     def finish_evaluation(state):
+         state["running"] = False
+         return state
+
      # Function to handle cancel button click
+     def cancel_evaluation(state):
+         # Note: This doesn't actually stop the evaluation process
+         # It only updates the UI state to appear canceled
+         state["running"] = False
          return [
+             state,
+             gr.update(interactive=True), # all_subjects_checkbox
+             gr.update(interactive=True), # num_subjects_slider
+             gr.update(interactive=True), # num_shots_slider
+             gr.update(interactive=True), # all_questions_checkbox
+             gr.update(interactive=True), # num_questions_slider
              gr.update(interactive=True), # eval_mmlu_button
+             gr.update(visible=False), # cancel_mmlu_button
+             "⚠️ Evaluation canceled by user (note: backend process may continue running)", # results_output
+             None # results_table
          ]
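As the comment in cancel_evaluation notes, the handler only resets the UI; the GPU job keeps running. Gradio does offer real cancellation via the cancels= argument on event listeners, which terminates the queued or running event for that session. A hedged sketch of that mechanism, not wired into this app (names are illustrative):

import time
import gradio as gr

with gr.Blocks() as sketch:
    out = gr.Markdown()
    run_btn = gr.Button("Run")
    stop_btn = gr.Button("Stop", variant="stop")

    def slow_job():
        time.sleep(60)  # stand-in for a long evaluation
        return "finished"

    run_event = run_btn.click(fn=slow_job, inputs=None, outputs=[out])
    # cancels= terminates the referenced event instead of letting it finish.
    stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[run_event])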
 
+     # Connect MMLU evaluation button with state tracking
      eval_mmlu_button.click(
+         fn=start_evaluation,
+         inputs=[evaluation_state],
          outputs=[
+             evaluation_state,
              all_subjects_checkbox,
              num_subjects_slider,
              num_shots_slider,
              all_questions_checkbox,
              num_questions_slider,
              eval_mmlu_button,
+             cancel_mmlu_button,
+             results_output,
+             results_table
          ]
      ).then(
          fn=run_mmlu_evaluation,

              all_questions_checkbox,
              num_questions_slider
          ]
+     ).then(
+         fn=finish_evaluation,
+         inputs=[evaluation_state],
+         outputs=[evaluation_state]
      )
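The chain above relies on two Gradio mechanisms: gr.State, a per-session value that persists only when it appears in both inputs and outputs, and .then(), which runs each step after the previous one completes. A compact sketch combining both (illustrative names, not from app.py):

import gradio as gr

with gr.Blocks() as sketch:
    count_state = gr.State(0)  # per-session value, like evaluation_state above
    msg = gr.Markdown()
    btn = gr.Button("Step")

    def bump(count):
        count += 1
        return count, f"step {count} started"

    def settle(count):
        return f"step {count} finished"

    # The state persists because it is listed in both inputs and outputs;
    # .then() runs settle only after bump completes.
    btn.click(fn=bump, inputs=[count_state], outputs=[count_state, msg]).then(
        fn=settle, inputs=[count_state], outputs=[msg]
    )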
 
      # Connect cancel button
      cancel_mmlu_button.click(
          fn=cancel_evaluation,
+         inputs=[evaluation_state],
          outputs=[
+             evaluation_state,
              all_subjects_checkbox,
              num_subjects_slider,
              num_shots_slider,