Spaces:
Running
Running
| import gradio as gr | |
| import pandas as pd | |
| from huggingface_hub import HfApi | |
# Benchmark datasets shown on the leaderboard. ``get_model_info`` only keeps
# evaluation results whose dataset name appears in this list.
DATASETS = [
    "mMARCO-fr",
    "BSARD",
]

# Hugging Face Hub repo ids of the evaluated models. Each list is loaded by
# ``load_all_results`` with its own "Type" label.

# Loaded with model_type="SINGLE".
SINGLE_VECTOR_MODELS = [
    "antoinelouis/biencoder-camemberta-base-mmarcoFR",
    "antoinelouis/biencoder-camembert-base-mmarcoFR",
    "antoinelouis/biencoder-distilcamembert-mmarcoFR",
    "antoinelouis/biencoder-camembert-L10-mmarcoFR",
    "antoinelouis/biencoder-camembert-L8-mmarcoFR",
    "antoinelouis/biencoder-camembert-L6-mmarcoFR",
    "antoinelouis/biencoder-camembert-L4-mmarcoFR",
    "antoinelouis/biencoder-camembert-L2-mmarcoFR",
    "antoinelouis/biencoder-electra-base-mmarcoFR",
    "antoinelouis/biencoder-mMiniLMv2-L12-mmarcoFR",
    "antoinelouis/biencoder-mMiniLMv2-L6-mmarcoFR",
    "antoinelouis/biencoder-mdebertav3-mmarcoFR",
    "OrdalieTech/Solon-embeddings-large-0.1",
    "OrdalieTech/Solon-embeddings-base-0.1",
]

# Loaded with model_type="MULTI".
MULTI_VECTOR_MODELS = [
    "antoinelouis/colbertv1-camembert-base-mmarcoFR",
    "antoinelouis/colbertv2-camembert-L4-mmarcoFR",
    "antoinelouis/colbert-xm",
]

# Loaded with model_type="SPARSE".
SPARSE_LEXICAL_MODELS = [
    "antoinelouis/spladev2-camembert-base-mmarcoFR",
]

# Loaded with model_type="CROSS".
CROSS_ENCODER_MODELS = [
    "antoinelouis/crossencoder-camemberta-L2-mmarcoFR",
    "antoinelouis/crossencoder-camemberta-L4-mmarcoFR",
    "antoinelouis/crossencoder-camemberta-L6-mmarcoFR",
    "antoinelouis/crossencoder-camemberta-L8-mmarcoFR",
    "antoinelouis/crossencoder-camemberta-L10-mmarcoFR",
    "antoinelouis/crossencoder-camemberta-base-mmarcoFR",
    "antoinelouis/crossencoder-camembert-L2-mmarcoFR",
    "antoinelouis/crossencoder-camembert-L4-mmarcoFR",
    "antoinelouis/crossencoder-camembert-L6-mmarcoFR",
    "antoinelouis/crossencoder-camembert-L8-mmarcoFR",
    "antoinelouis/crossencoder-camembert-L10-mmarcoFR",
    "antoinelouis/crossencoder-camembert-base-mmarcoFR",
    "antoinelouis/crossencoder-camembert-large-mmarcoFR",
    "antoinelouis/crossencoder-distilcamembert-mmarcoFR",
    "antoinelouis/crossencoder-electra-base-mmarcoFR",
    "antoinelouis/crossencoder-me5-base-mmarcoFR",
    "antoinelouis/crossencoder-me5-small-mmarcoFR",
    "antoinelouis/crossencoder-t5-base-mmarcoFR",
    "antoinelouis/crossencoder-t5-small-mmarcoFR",
    "antoinelouis/crossencoder-mt5-base-mmarcoFR",
    "antoinelouis/crossencoder-mt5-small-mmarcoFR",
    "antoinelouis/crossencoder-xlm-roberta-base-mmarcoFR",
    "antoinelouis/crossencoder-mdebertav3-base-mmarcoFR",
    "antoinelouis/crossencoder-mMiniLMv2-L12-mmarcoFR",
    "antoinelouis/crossencoder-mMiniLMv2-L6-mmarcoFR",
]
# Leaderboard schema: column name -> datatype string passed to ``gr.Dataframe``.
# Insertion order is the column order of the rows built by ``get_model_info``.
COLUMNS = {
    # Descriptive columns.
    "Model": "html",
    "#Params (M)": "number",
    "Type": "str",
    "Dataset": "str",
    # Retrieval metrics, all numeric.
    **dict.fromkeys(
        (
            "Recall@1000",
            "Recall@500",
            "Recall@100",
            "Recall@10",
            "MRR@10",
            "nDCG@10",
            "MAP@10",
        ),
        "number",
    ),
}
def get_model_info(model_id: str, model_type: str) -> pd.DataFrame:
    """Build the leaderboard rows for one Hugging Face Hub model.

    The evaluation results declared in the model card are grouped by dataset.
    Only datasets listed in ``DATASETS`` are kept, and only metrics whose name
    is a key of ``COLUMNS`` are stored; every other column stays ``None``.

    Args:
        model_id: Hub repo id, e.g. ``"antoinelouis/colbert-xm"``.
        model_type: Label written to the "Type" column of every row.

    Returns:
        A DataFrame with one row per benchmark dataset reported in the card.
        It is empty when the card has no metadata, no evaluation results, or
        no result on a dataset from ``DATASETS``.
    """
    model_info = HfApi().model_info(model_id)

    # Both the card metadata and its evaluation results are optional on the
    # Hub; treat a missing value as "nothing reported" instead of crashing.
    eval_results = getattr(model_info.card_data, "eval_results", None) or []

    # These values are identical for every row of the model, so compute once.
    model_link = f'<a href="https://huggingface.co/{model_id}" target="_blank" style="color: blue; text-decoration: none;">{model_id}</a>'
    # Parameter count in millions; unknown when no safetensors metadata exists.
    n_params = round(model_info.safetensors.total/1e6, 0) if model_info.safetensors else None

    rows = {}
    for result in eval_results:
        dataset = result.dataset_name
        if dataset not in DATASETS:
            continue
        if dataset not in rows:
            # First result seen for this dataset: start a row with all columns.
            row = {key: None for key in COLUMNS}
            row["Model"] = model_link
            row["#Params (M)"] = n_params
            row["Type"] = model_type
            row["Dataset"] = dataset
            rows[dataset] = row
        if result.metric_name in rows[dataset]:
            rows[dataset][result.metric_name] = result.metric_value
    return pd.DataFrame(list(rows.values()))
def load_all_results() -> pd.DataFrame:
    """Assemble the full leaderboard table.

    Rows come from ``./baselines.csv`` (external baseline models), followed by
    the rows returned by ``get_model_info`` for every model id in the four
    model lists, in the order SINGLE, MULTI, SPARSE, CROSS.

    Returns:
        One DataFrame holding all rows, with a fresh 0..n-1 index and every
        Recall/MRR/nDCG/MAP column rounded to one decimal.
    """
    model_groups = (
        ("SINGLE", SINGLE_VECTOR_MODELS),
        ("MULTI", MULTI_VECTOR_MODELS),
        ("SPARSE", SPARSE_LEXICAL_MODELS),
        ("CROSS", CROSS_ENCODER_MODELS),
    )

    # Load results from external baseline models.
    frames = [pd.read_csv('./baselines.csv')]

    # Load results from own Hugging Face models.
    frames.extend(
        get_model_info(model_id, model_type=model_type)
        for model_type, model_ids in model_groups
        for model_id in model_ids
    )

    # Concatenate once: concatenating inside the loop re-copies the growing
    # frame for every model. ignore_index avoids the duplicate row labels that
    # stacking many small frames would otherwise produce.
    df = pd.concat(frames, ignore_index=True)

    # Round all metrics to 1 decimal.
    metric_markers = ("Recall", "MRR", "nDCG", "MAP")
    for col in df.columns:
        if any(marker in col for marker in metric_markers):
            # get_model_info fills unreported metrics with None, which can leave
            # a column with object dtype; make it numeric before rounding.
            df[col] = pd.to_numeric(df[col]).round(1)
    return df
def filter_dataf_by_dataset(dataf: pd.DataFrame, dataset_name: str, sort_by: str) -> pd.DataFrame:
    """Return the rows of one dataset, best score first.

    Args:
        dataf: Table that contains a "Dataset" column.
        dataset_name: Value of "Dataset" to keep.
        sort_by: Column used to order the result, in descending order.

    Returns:
        A new DataFrame without the "Dataset" column. Original row labels are
        preserved.
    """
    belongs_to_dataset = dataf["Dataset"] == dataset_name
    subset = dataf.loc[belongs_to_dataset]
    # The column is constant after filtering, so it is dropped from the view.
    subset = subset.drop(columns=["Dataset"])
    return subset.sort_values(by=sort_by, ascending=False)
# Maps each "Model size" checkbox label to a predicate over the "#Params (M)"
# column (values in millions of parameters).
_SIZE_FILTERS = {
    'Small (< 100M)': lambda params: params < 100,
    'Base (100M-300M)': lambda params: (params >= 100) & (params <= 300),
    'Large (300M-500M)': lambda params: (params >= 300) & (params <= 500),
    'Extra-large (500M+)': lambda params: params > 500,
}


def update_table(dataf: pd.DataFrame, query: str, selected_types: list, selected_sizes: list) -> pd.DataFrame:
    """Apply the search box and checkbox filters to a leaderboard table.

    Args:
        dataf: Table with "Model", "Type" and "#Params (M)" columns. It is
            not modified.
        query: Free-text search, matched case-insensitively as a literal
            substring of the visible model name. Empty means no text filter.
        selected_types: Checked "Model type" labels such as
            ``'Cross-encoder (CROSS)'``. Empty means all types.
        selected_sizes: Checked "Model size" labels. Rows matching any
            selected bucket are kept; empty means all sizes. Unknown labels
            are ignored.

    Returns:
        A filtered copy of ``dataf``.
    """
    filtered_df = dataf.copy()

    if selected_types:
        # The type code is the last word of the label without its parentheses,
        # e.g. 'Cross-encoder (CROSS)' -> 'CROSS'.
        type_codes = [label.split()[-1][1:-1] for label in selected_types]
        filtered_df = filtered_df[filtered_df['Type'].isin(type_codes)]

    size_predicates = [_SIZE_FILTERS[label] for label in (selected_sizes or []) if label in _SIZE_FILTERS]
    if size_predicates:
        params = filtered_df['#Params (M)']
        size_conditions = [predicate(params) for predicate in size_predicates]
        filtered_df = filtered_df[pd.concat(size_conditions, axis=1).any(axis=1)]

    if query:
        # "Model" cells may hold an HTML anchor; search the visible text only,
        # so that words like "href" or "blue" do not match every linked row.
        visible_names = filtered_df['Model'].str.replace(r'<[^>]*>', '', regex=True)
        # regex=False: user input such as "c++" or "(" is not a valid pattern.
        # na=False: rows without a model name simply do not match.
        matches = visible_names.str.contains(query, case=False, regex=False, na=False)
        filtered_df = filtered_df[matches]
    return filtered_df
# Leaderboard UI: header, search/filter widgets, one tab per dataset, citation.
with gr.Blocks() as demo:
    # NOTE(review): the mailto address below looks like an obfuscation
    # placeholder rather than a real address; confirm against the deployed app.
    gr.HTML("""
        <div style="display: flex; flex-direction: column; align-items: center;">
            <div style="align-self: flex-start;">
                <a href="mailto:[email protected]" target="_blank" style="color: blue; text-decoration: none;">Contact/Submissions</a>
            </div>
            <h1 style="margin: 0;">🥇 DécouvrIR</h1>
            <span>A Benchmark for Evaluating the Robustness of Information Retrieval Models in French</span>
        </div>
    """)

    # Create the Pandas dataframes (one per dataset)
    all_df = load_all_results()
    mmarco_df = filter_dataf_by_dataset(all_df, dataset_name="mMARCO-fr", sort_by="Recall@500")
    bsard_df = filter_dataf_by_dataset(all_df, dataset_name="BSARD", sort_by="Recall@500")

    # Search and filter widgets
    with gr.Column():
        with gr.Row():
            search_bar = gr.Textbox(placeholder=" 🔍 Search for a model...", show_label=False, elem_id="search-bar")
        with gr.Row():
            filter_type = gr.CheckboxGroup(
                label="Model type",
                choices=[
                    'Single-vector dense bi-encoder (SINGLE)',
                    'Multi-vector dense bi-encoder (MULTI)',
                    'Sparse lexical model (SPARSE)',
                    'Cross-encoder (CROSS)',
                ],
                value=[],
                interactive=True,
                elem_id="filter-type",
            )
        with gr.Row():
            filter_size = gr.CheckboxGroup(
                label="Model size",
                choices=['Small (< 100M)', 'Base (100M-300M)', 'Large (300M-500M)', 'Extra-large (500M+)'],
                value=[],
                interactive=True,
                elem_id="filter-size",
            )

    # Leaderboard tables
    with gr.Tabs():
        with gr.TabItem("🌐 mMARCO-fr"):
            gr.HTML("""
                <p>The <a href="https://huggingface.co/datasets/unicamp-dl/mmarco" target="_blank" style="color: blue; text-decoration: none;">mMARCO</a> dataset is a machine-translated version of
                the widely popular MS MARCO dataset across 13 languages (including French) for studying <strong> domain-general</strong> passage retrieval.</p>
                <p>The evaluation is performed on <strong>6,980 dev questions</strong> labeled with relevant passages to be retrieved from a corpus of <strong>8,841,823 candidates</strong>.</p>
            """)
            mmarco_table = gr.Dataframe(
                value=mmarco_df,
                # Columns of the CSV that are not declared in COLUMNS are
                # rendered as plain text instead of raising KeyError.
                datatype=[COLUMNS.get(col, "str") for col in mmarco_df.columns],
                interactive=False,
                elem_classes="text-sm",
            )
        with gr.TabItem("⚖️ BSARD"):
            gr.HTML("""
                <p>The <a href="https://huggingface.co/datasets/maastrichtlawtech/bsard" target="_blank" style="color: blue; text-decoration: none;">Belgian Statutory Article Retrieval Dataset (BSARD)</a> is a
                French native dataset for studying <strong>legal</strong> document retrieval.</p>
                <p>The evaluation is performed on <strong>222 test questions</strong> labeled by experienced jurists with relevant Belgian law articles to be retrieved from a corpus of <strong>22,633 candidates</strong>.</p>
                <i>[Coming soon...]</i>
            """)
            # TODO: enable once BSARD results are published.
            # bsard_table = gr.Dataframe(
            #     value=bsard_df,
            #     datatype=[COLUMNS[col] for col in bsard_df.columns],
            #     interactive=False,
            #     elem_classes="text-sm",
            # )

    # Update tables on filter widgets change.
    widgets = [search_bar, filter_type, filter_size]

    def refresh_mmarco_table(q, t, s):
        """Re-filter the mMARCO table from the current widget values."""
        return update_table(dataf=mmarco_df, query=q, selected_types=t, selected_sizes=s)

    # Every widget passes the values of all three widgets to the same callback.
    for w in widgets:
        w.change(fn=refresh_mmarco_table, inputs=widgets, outputs=[mmarco_table])
        # TODO: register the same update for bsard_table once it is enabled.
        #w.change(fn=lambda q, t, s: update_table(dataf=bsard_df, query=q, selected_types=t, selected_sizes=s), inputs=widgets, outputs=[bsard_table])

    # Citation
    with gr.Column():
        with gr.Row():
            gr.HTML("""
                <h2>Citation</h2>
                <p>For attribution in academic contexts, please cite this benchmark and any of the models released by <a href="https://huggingface.co/antoinelouis" target="_blank" style="color: blue; text-decoration: none;">@antoinelouis</a> as follows:</p>
            """)
        with gr.Row():
            # BibTeX field values must be delimited by braces or double quotes;
            # single quotes are not accepted by BibTeX parsers.
            citation_block = (
                "@online{louis2024decouvrir,\n"
                "\tauthor = {Antoine Louis},\n"
                "\ttitle = {DécouvrIR: A Benchmark for Evaluating the Robustness of Information Retrieval Models in French},\n"
                "\tpublisher = {Hugging Face},\n"
                "\tmonth = {mar},\n"
                "\tyear = {2024},\n"
                "\turl = {https://huggingface.co/spaces/antoinelouis/decouvrir},\n"
                "}\n"
            )
            gr.Code(citation_block, language=None, show_label=False)

demo.launch()