update app.py
Browse files
app.py
CHANGED
|
@@ -14,7 +14,7 @@ zero = torch.Tensor([0]).to(device)
|
|
| 14 |
print(f"Device being used: {zero.device}")
|
| 15 |
|
| 16 |
@spaces.GPU
|
| 17 |
-
def evaluate_model(model_id):
|
| 18 |
model = SentenceTransformer(model_id, device=device)
|
| 19 |
matryoshka_dimensions = [768, 512, 256, 128, 64]
|
| 20 |
|
|
@@ -26,7 +26,7 @@ def evaluate_model(model_id):
|
|
| 26 |
"split": "train",
|
| 27 |
"size": 7000,
|
| 28 |
"columns": ("question", "context"),
|
| 29 |
-
"sample_size":
|
| 30 |
},
|
| 31 |
{
|
| 32 |
"name": "MLQA",
|
|
@@ -35,7 +35,7 @@ def evaluate_model(model_id):
|
|
| 35 |
"split": "validation",
|
| 36 |
"size": 500,
|
| 37 |
"columns": ("question", "context"),
|
| 38 |
-
"sample_size":
|
| 39 |
},
|
| 40 |
{
|
| 41 |
"name": "ARCD",
|
|
@@ -43,8 +43,8 @@ def evaluate_model(model_id):
|
|
| 43 |
"split": "train",
|
| 44 |
"size": None,
|
| 45 |
"columns": ("question", "context"),
|
| 46 |
-
"sample_size":
|
| 47 |
-
"last_rows": True # Take the last
|
| 48 |
}
|
| 49 |
]
|
| 50 |
|
|
@@ -58,7 +58,7 @@ def evaluate_model(model_id):
|
|
| 58 |
else:
|
| 59 |
dataset = load_dataset(dataset_info["dataset_id"], split=dataset_info["split"])
|
| 60 |
|
| 61 |
-
# Take last
|
| 62 |
if dataset_info.get("last_rows"):
|
| 63 |
dataset = dataset.select(range(len(dataset) - dataset_info["sample_size"], len(dataset)))
|
| 64 |
else:
|
|
@@ -136,13 +136,17 @@ def evaluate_model(model_id):
|
|
| 136 |
return result_df, charts[0], charts[1], charts[2]
|
| 137 |
|
| 138 |
# Define the Gradio interface
|
| 139 |
-
def display_results(model_name):
|
| 140 |
-
result_df, chart1, chart2, chart3 = evaluate_model(model_name)
|
| 141 |
return result_df, chart1, chart2, chart3
|
| 142 |
|
|
|
|
| 143 |
demo = gr.Interface(
|
| 144 |
fn=display_results,
|
| 145 |
-
inputs=
|
|
|
|
|
|
|
|
|
|
| 146 |
outputs=[
|
| 147 |
gr.Dataframe(label="Evaluation Results"),
|
| 148 |
gr.Plot(label="Financial Dataset"),
|
|
@@ -164,4 +168,7 @@ demo = gr.Interface(
|
|
| 164 |
css="footer {visibility: hidden;}"
|
| 165 |
)
|
| 166 |
|
| 167 |
-
demo.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
print(f"Device being used: {zero.device}")
|
| 15 |
|
| 16 |
@spaces.GPU
|
| 17 |
+
def evaluate_model(model_id, num_questions):
|
| 18 |
model = SentenceTransformer(model_id, device=device)
|
| 19 |
matryoshka_dimensions = [768, 512, 256, 128, 64]
|
| 20 |
|
|
|
|
| 26 |
"split": "train",
|
| 27 |
"size": 7000,
|
| 28 |
"columns": ("question", "context"),
|
| 29 |
+
"sample_size": num_questions
|
| 30 |
},
|
| 31 |
{
|
| 32 |
"name": "MLQA",
|
|
|
|
| 35 |
"split": "validation",
|
| 36 |
"size": 500,
|
| 37 |
"columns": ("question", "context"),
|
| 38 |
+
"sample_size": num_questions
|
| 39 |
},
|
| 40 |
{
|
| 41 |
"name": "ARCD",
|
|
|
|
| 43 |
"split": "train",
|
| 44 |
"size": None,
|
| 45 |
"columns": ("question", "context"),
|
| 46 |
+
"sample_size": num_questions,
|
| 47 |
+
"last_rows": True # Take the last n rows
|
| 48 |
}
|
| 49 |
]
|
| 50 |
|
|
|
|
| 58 |
else:
|
| 59 |
dataset = load_dataset(dataset_info["dataset_id"], split=dataset_info["split"])
|
| 60 |
|
| 61 |
+
# Take the last n rows if specified
|
| 62 |
if dataset_info.get("last_rows"):
|
| 63 |
dataset = dataset.select(range(len(dataset) - dataset_info["sample_size"], len(dataset)))
|
| 64 |
else:
|
|
|
|
| 136 |
return result_df, charts[0], charts[1], charts[2]
|
| 137 |
|
| 138 |
# Define the Gradio interface
|
| 139 |
+
def display_results(model_name, num_questions):
|
| 140 |
+
result_df, chart1, chart2, chart3 = evaluate_model(model_name, num_questions)
|
| 141 |
return result_df, chart1, chart2, chart3
|
| 142 |
|
| 143 |
+
# Gradio interface with a slider to choose the number of questions (1 to 500)
|
| 144 |
demo = gr.Interface(
|
| 145 |
fn=display_results,
|
| 146 |
+
inputs=[
|
| 147 |
+
gr.Textbox(label="Enter a Hugging Face Model ID", placeholder="e.g., Omartificial-Intelligence-Space/GATE-AraBert-v1"),
|
| 148 |
+
gr.Slider(label="Number of Questions", minimum=1, maximum=500, step=1, value=500)
|
| 149 |
+
],
|
| 150 |
outputs=[
|
| 151 |
gr.Dataframe(label="Evaluation Results"),
|
| 152 |
gr.Plot(label="Financial Dataset"),
|
|
|
|
| 168 |
css="footer {visibility: hidden;}"
|
| 169 |
)
|
| 170 |
|
| 171 |
+
demo.launch(share=True)
|
| 172 |
+
|
| 173 |
+
# Add the footer
|
| 174 |
+
print("\nCreated by Omar Najar | Omartificial Intelligence Space")
|