import gradio as gr
from datasets import disable_caching, load_dataset
from transformer_ranker import TransformerRanker

from demo.config import SAMPLE_SIZE, MAX_SAMPLE_SIZE, ALL_LMS, PRESELECTED_LMS
from demo.utils import (
    BANNER,
    FOOTER,
    CSS,
    UNSET,
    EmbeddingProgressTracker,
    compute_ratio,
    validate_dataset,
    preprocess_dataset,
    ensure_dataset_is_loaded,
)

disable_caching()
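# The constants and helpers above come from this Space's `demo` package (not shown here).
# Assumed roles, inferred from how they are used below: SAMPLE_SIZE / MAX_SAMPLE_SIZE set the
# default and maximum of the sampling slider, ALL_LMS / PRESELECTED_LMS are lists of Hugging
# Face model handles, and UNSET is the sentinel value for an unselected dropdown.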
with gr.Blocks(css=CSS, theme=None) as demo:
    gr.Markdown(BANNER, elem_classes="banner")

    ##### 1. Load data #####
    gr.Markdown("## 📚 Load Data")
    gr.Markdown(
        "Pick a dataset from the Hugging Face Hub (e.g. `trec`). This defines your downstream task."
    )

    with gr.Group():
        dataset = gr.State(None)
        dataset_id = gr.Textbox(
            label="Dataset identifier",
            placeholder="try: trec, conll2003, ag_news",
            max_lines=1,
        )
        load_dataset_button = gr.Button(
            value="Load data",
            variant="primary",
            interactive=True,
        )
        # Enable the load button only if the dataset exists on the Hub
        dataset_id.change(validate_dataset, inputs=dataset_id, outputs=load_dataset_button)

    gr.Markdown(
        "⚡️ Speed mode on: tweak the downsampling ratio in *Dataset Setup* for quicker runs. "
        "Unlock the full data via the [framework](https://github.com/flairNLP/transformer-ranker)."
    )
    ##### Data preprocessing #####
    with gr.Accordion("Dataset Setup", open=False) as dataset_config:
        with gr.Row() as dataset_details:
            dataset_id_label = gr.Label("", label="Dataset")
            num_samples = gr.State(0)
            num_samples_label = gr.Label("", label="Dataset size")
            num_samples.change(lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label])

        with gr.Row():
            text_column = gr.Dropdown("", label="Text Column")
            text_pair_column = gr.Dropdown("", label="Text Pair")

        with gr.Row():
            label_column = gr.Dropdown("", label="Labels")
            task_category = gr.Dropdown("", label="Downstream Task")

        with gr.Group():
            downsample_ratio = gr.State(0.0)
            sampling_rate = gr.Slider(20, MAX_SAMPLE_SIZE, label="Sampling rate", value=SAMPLE_SIZE, step=1)
            downsample_ratio_label = gr.Label("", label="Sampling rate")
            downsample_ratio.change(
                lambda x: f"{x:.1%}",
                inputs=[downsample_ratio],
                outputs=[downsample_ratio_label],
            )
            sampling_rate.change(
                compute_ratio,
                inputs=[sampling_rate, num_samples],
                outputs=downsample_ratio,
            )
            num_samples.change(
                compute_ratio,
                inputs=[sampling_rate, num_samples],
                outputs=downsample_ratio,
            )
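    # Note on the wiring above: compute_ratio (demo.utils) presumably converts the absolute
    # sampling rate into a fraction of the loaded dataset, e.g. a rate of 500 on a 5,000-row
    # dataset would give a downsample_ratio of 0.1, shown as "10.0%" by the label formatter.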
    def load_hf_dataset(dataset_id):
        try:
            dataset = load_dataset(dataset_id, trust_remote_code=True)
            dataset_details = preprocess_dataset(dataset)
        except ValueError as e:
            # Abort and keep all outputs unchanged if the dataset cannot be loaded
            raise gr.Error(f"Watch out, single datasets only. Cannot load dataset: {e}")
        return (gr.update(value="Loaded"), dataset_id, dataset, *dataset_details)
    load_dataset_button.click(
        load_hf_dataset,
        inputs=[dataset_id],
        outputs=[
            load_dataset_button,
            dataset_id_label,
            dataset,
            task_category,
            text_column,
            text_pair_column,
            label_column,
            num_samples,
        ],
        scroll_to_output=True,
    )
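    # preprocess_dataset (demo.utils) is assumed to return five values in this order:
    # task_category, text_column, text_pair_column, label_column, num_samples, so that the
    # unpacked tuple in load_hf_dataset lines up with the outputs list above.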
    ########## 2. Select LMs ##########
    gr.Markdown("## 🧠 Select Language Models")
    gr.Markdown(
        "Add two or more pretrained models to compare. "
        "Stick to smaller models here since the demo runs on CPU."
    )

    with gr.Group():
        model_options = [(model_handle.split("/")[-1], model_handle) for model_handle in ALL_LMS]
        models = gr.CheckboxGroup(choices=model_options, label="Model List", value=PRESELECTED_LMS)
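    # Each checkbox shows the short model name but keeps the full Hub handle as its value,
    # e.g. a handle like "prajjwal1/bert-tiny" (illustrative) would be displayed as "bert-tiny".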
    ########## 3. Run ranking ##########
    gr.Markdown("## 🏆 Rank Models")
    gr.Markdown(
        "Rank models by transferability to your task. "
        "Need more control? Tweak the transferability metric and layer aggregation in *Advanced Settings*."
    )

    with gr.Group():
        submit_button = gr.Button("Run ranking", variant="primary", interactive=False)
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                estimator = gr.Dropdown(
                    choices=["hscore", "logme", "knn"],
                    label="Transferability metric",
                    value="hscore",
                )
                layer_aggregator = gr.Dropdown(
                    choices=["lastlayer", "layermean", "bestlayer"],
                    label="Layer aggregation",
                    value="layermean",
                )
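    # Rough meaning of the options above (see the transformer-ranker docs for details):
    # the estimator scores embeddings against labels (hscore = H-score, logme = LogME,
    # knn = k-nearest-neighbor accuracy); layer aggregation controls which hidden states are
    # scored (lastlayer uses the final layer, layermean averages all layers, bestlayer keeps
    # the highest-scoring layer).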
    # The ranking button becomes active once a dataset is loaded and its columns are set
    dataset.change(
        ensure_dataset_is_loaded,
        inputs=[dataset, text_column, label_column, task_category],
        outputs=submit_button,
    )
    label_column.change(
        ensure_dataset_is_loaded,
        inputs=[dataset, text_column, label_column, task_category],
        outputs=submit_button,
    )
    text_column.change(
        ensure_dataset_is_loaded,
        inputs=[dataset, text_column, label_column, task_category],
        outputs=submit_button,
    )
    def rank_models(
        dataset,
        downsample_ratio,
        selected_models,
        layer_aggregator,
        estimator,
        text_column,
        text_pair_column,
        label_column,
        task_category,
        progress=gr.Progress(),
    ):
        # Validate the column selections before starting a run
        if text_column == UNSET:
            raise gr.Error("Text column is required.")
        if label_column == UNSET:
            raise gr.Error("Label column is required.")
        if task_category == UNSET:
            raise gr.Error("Task category is required.")
        if text_pair_column == UNSET:
            text_pair_column = None

        progress(0.0, "Starting")
        with EmbeddingProgressTracker(progress=progress, model_names=selected_models) as tracker:
            try:
                ranker = TransformerRanker(
                    dataset,
                    dataset_downsample=downsample_ratio,
                    text_column=text_column,
                    text_pair_column=text_pair_column,
                    label_column=label_column,
                    task_category=task_category,
                )
                results = ranker.run(
                    models=selected_models,
                    layer_aggregator=layer_aggregator,
                    estimator=estimator,
                    batch_size=64,
                    tracker=tracker,
                )
                # Sort models by transferability score, best first
                sorted_results = sorted(results._results.items(), key=lambda item: item[1], reverse=True)
                return [(i + 1, model, score) for i, (model, score) in enumerate(sorted_results)]
            except Exception as e:
                gr.Warning(f"Ranking issue: {e}")
                return []
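    # rank_models returns rows of (rank, model handle, score) that populate the leaderboard below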
| gr.Markdown("**Leaderboard:** higher score → better downstream performance.") | |
| ranking_results = gr.Dataframe( | |
| headers=["Rank", "Model", "Score"], | |
| datatype=["number", "str", "number"], | |
| value=[["-", "-", "-"]], | |
| interactive=False | |
| ) | |
    submit_button.click(
        rank_models,
        inputs=[
            dataset,
            downsample_ratio,
            models,
            layer_aggregator,
            estimator,
            text_column,
            text_pair_column,
            label_column,
            task_category,
        ],
        outputs=ranking_results,
        scroll_to_output=True,
    )

    gr.Markdown(FOOTER)
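# Assumed entry point: launch the Blocks app when the script is run directly, as is typical
# for a Gradio Space.
if __name__ == "__main__":
    demo.launch()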