update app.py
Browse files
app.py
CHANGED
|
@@ -18,26 +18,30 @@ def evaluate_model(model_id, num_questions):
|
|
| 18 |
model = SentenceTransformer(model_id, device=device)
|
| 19 |
matryoshka_dimensions = [768, 512, 256, 128, 64]
|
| 20 |
|
| 21 |
-
# Prepare datasets (
|
| 22 |
datasets_info = [
|
| 23 |
{
|
| 24 |
"name": "Financial",
|
| 25 |
"dataset_id": "Omartificial-Intelligence-Space/Arabic-finanical-rag-embedding-dataset",
|
| 26 |
-
"split":
|
| 27 |
-
"columns": ("question", "context")
|
|
|
|
| 28 |
},
|
| 29 |
{
|
| 30 |
"name": "MLQA",
|
| 31 |
"dataset_id": "google/xtreme",
|
| 32 |
"subset": "MLQA.ar.ar",
|
| 33 |
-
"split":
|
| 34 |
-
"columns": ("question", "context")
|
|
|
|
| 35 |
},
|
| 36 |
{
|
| 37 |
"name": "ARCD",
|
| 38 |
"dataset_id": "hsseinmz/arcd",
|
| 39 |
-
"split":
|
| 40 |
-
"columns": ("question", "context")
|
|
|
|
|
|
|
| 41 |
}
|
| 42 |
]
|
| 43 |
|
|
@@ -45,12 +49,18 @@ def evaluate_model(model_id, num_questions):
|
|
| 45 |
scores_by_dataset = {}
|
| 46 |
|
| 47 |
for dataset_info in datasets_info:
|
| 48 |
-
# Load the dataset
|
| 49 |
if "subset" in dataset_info:
|
| 50 |
dataset = load_dataset(dataset_info["dataset_id"], dataset_info["subset"], split=dataset_info["split"])
|
| 51 |
else:
|
| 52 |
dataset = load_dataset(dataset_info["dataset_id"], split=dataset_info["split"])
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
# Rename columns to 'anchor' and 'positive'
|
| 55 |
dataset = dataset.rename_column(dataset_info["columns"][0], "anchor")
|
| 56 |
dataset = dataset.rename_column(dataset_info["columns"][1], "positive")
|
|
|
|
| 18 |
model = SentenceTransformer(model_id, device=device)
|
| 19 |
matryoshka_dimensions = [768, 512, 256, 128, 64]
|
| 20 |
|
| 21 |
+
# Prepare datasets (Load entire split, then select num_questions)
|
| 22 |
datasets_info = [
|
| 23 |
{
|
| 24 |
"name": "Financial",
|
| 25 |
"dataset_id": "Omartificial-Intelligence-Space/Arabic-finanical-rag-embedding-dataset",
|
| 26 |
+
"split": "train", # Only train split
|
| 27 |
+
"columns": ("question", "context"),
|
| 28 |
+
"sample_size": num_questions
|
| 29 |
},
|
| 30 |
{
|
| 31 |
"name": "MLQA",
|
| 32 |
"dataset_id": "google/xtreme",
|
| 33 |
"subset": "MLQA.ar.ar",
|
| 34 |
+
"split": "validation", # Only validation split
|
| 35 |
+
"columns": ("question", "context"),
|
| 36 |
+
"sample_size": num_questions
|
| 37 |
},
|
| 38 |
{
|
| 39 |
"name": "ARCD",
|
| 40 |
"dataset_id": "hsseinmz/arcd",
|
| 41 |
+
"split": "train", # Only train split
|
| 42 |
+
"columns": ("question", "context"),
|
| 43 |
+
"sample_size": num_questions,
|
| 44 |
+
"last_rows": True # Take the last num_questions rows
|
| 45 |
}
|
| 46 |
]
|
| 47 |
|
|
|
|
| 49 |
scores_by_dataset = {}
|
| 50 |
|
| 51 |
for dataset_info in datasets_info:
|
| 52 |
+
# Load the full dataset split and limit it afterward
|
| 53 |
if "subset" in dataset_info:
|
| 54 |
dataset = load_dataset(dataset_info["dataset_id"], dataset_info["subset"], split=dataset_info["split"])
|
| 55 |
else:
|
| 56 |
dataset = load_dataset(dataset_info["dataset_id"], split=dataset_info["split"])
|
| 57 |
|
| 58 |
+
# Select the required number of rows
|
| 59 |
+
if dataset_info.get("last_rows"):
|
| 60 |
+
dataset = dataset.select(range(len(dataset) - dataset_info["sample_size"], len(dataset))) # Take last n rows
|
| 61 |
+
else:
|
| 62 |
+
dataset = dataset.select(range(min(dataset_info["sample_size"], len(dataset)))) # Take first n rows
|
| 63 |
+
|
| 64 |
# Rename columns to 'anchor' and 'positive'
|
| 65 |
dataset = dataset.rename_column(dataset_info["columns"][0], "anchor")
|
| 66 |
dataset = dataset.rename_column(dataset_info["columns"][1], "positive")
|