| """ | |
| DatasetArgs Class | |
| ================= | |
| """ | |
| from dataclasses import dataclass | |
| import textattack | |
| from textattack.shared.utils import ARGS_SPLIT_TOKEN, load_module_from_file | |
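# Each registry below maps a pretrained model name to the arguments used to load
# its evaluation dataset. Judging by how these tuples are unpacked into
# ``textattack.datasets.HuggingFaceDataset`` in ``DatasetArgs._create_dataset_from_args``
# below, the positional layout appears to be (dataset_name, subset, split, ...),
# where longer tuples also carry extras such as a label remapping dict
# (e.g. {0: 1, 1: 2, 2: 0} for the NLI tasks) and, as the final element, what
# looks like an output scale factor (e.g. 5.0 for STS-B regression scores).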
HUGGINGFACE_DATASET_BY_MODEL = {
    #
    # bert-base-uncased
    #
    "bert-base-uncased-ag-news": ("ag_news", None, "test"),
    "bert-base-uncased-cola": ("glue", "cola", "validation"),
    "bert-base-uncased-imdb": ("imdb", None, "test"),
    "bert-base-uncased-mnli": (
        "glue",
        "mnli",
        "validation_matched",
        None,
        {0: 1, 1: 2, 2: 0},
    ),
    "bert-base-uncased-mrpc": ("glue", "mrpc", "validation"),
    "bert-base-uncased-qnli": ("glue", "qnli", "validation"),
    "bert-base-uncased-qqp": ("glue", "qqp", "validation"),
    "bert-base-uncased-rte": ("glue", "rte", "validation"),
    "bert-base-uncased-sst2": ("glue", "sst2", "validation"),
    "bert-base-uncased-stsb": (
        "glue",
        "stsb",
        "validation",
        None,
        None,
        None,
        5.0,
    ),
    "bert-base-uncased-wnli": ("glue", "wnli", "validation"),
    "bert-base-uncased-mr": ("rotten_tomatoes", None, "test"),
    "bert-base-uncased-snli": ("snli", None, "test", None, {0: 1, 1: 2, 2: 0}),
    "bert-base-uncased-yelp": ("yelp_polarity", None, "test"),
    #
    # distilbert-base-cased
    #
    "distilbert-base-cased-cola": ("glue", "cola", "validation"),
    "distilbert-base-cased-mrpc": ("glue", "mrpc", "validation"),
    "distilbert-base-cased-qqp": ("glue", "qqp", "validation"),
    "distilbert-base-cased-snli": ("snli", None, "test"),
    "distilbert-base-cased-sst2": ("glue", "sst2", "validation"),
    "distilbert-base-cased-stsb": (
        "glue",
        "stsb",
        "validation",
        None,
        None,
        None,
        5.0,
    ),
    "distilbert-base-uncased-ag-news": ("ag_news", None, "test"),
    "distilbert-base-uncased-cola": ("glue", "cola", "validation"),
    "distilbert-base-uncased-imdb": ("imdb", None, "test"),
    "distilbert-base-uncased-mnli": (
        "glue",
        "mnli",
        "validation_matched",
        None,
        {0: 1, 1: 2, 2: 0},
    ),
    "distilbert-base-uncased-mr": ("rotten_tomatoes", None, "test"),
    "distilbert-base-uncased-mrpc": ("glue", "mrpc", "validation"),
    "distilbert-base-uncased-qnli": ("glue", "qnli", "validation"),
    "distilbert-base-uncased-rte": ("glue", "rte", "validation"),
    "distilbert-base-uncased-wnli": ("glue", "wnli", "validation"),
    #
    # roberta-base (RoBERTa is cased by default)
    #
    "roberta-base-ag-news": ("ag_news", None, "test"),
    "roberta-base-cola": ("glue", "cola", "validation"),
    "roberta-base-imdb": ("imdb", None, "test"),
    "roberta-base-mr": ("rotten_tomatoes", None, "test"),
    "roberta-base-mrpc": ("glue", "mrpc", "validation"),
    "roberta-base-qnli": ("glue", "qnli", "validation"),
    "roberta-base-rte": ("glue", "rte", "validation"),
    "roberta-base-sst2": ("glue", "sst2", "validation"),
    "roberta-base-stsb": ("glue", "stsb", "validation", None, None, None, 5.0),
    "roberta-base-wnli": ("glue", "wnli", "validation"),
    #
    # albert-base-v2 (ALBERT is cased by default)
    #
    "albert-base-v2-ag-news": ("ag_news", None, "test"),
    "albert-base-v2-cola": ("glue", "cola", "validation"),
    "albert-base-v2-imdb": ("imdb", None, "test"),
    "albert-base-v2-mr": ("rotten_tomatoes", None, "test"),
    "albert-base-v2-rte": ("glue", "rte", "validation"),
    "albert-base-v2-qqp": ("glue", "qqp", "validation"),
    "albert-base-v2-snli": ("snli", None, "test"),
    "albert-base-v2-sst2": ("glue", "sst2", "validation"),
    "albert-base-v2-stsb": ("glue", "stsb", "validation", None, None, None, 5.0),
    "albert-base-v2-wnli": ("glue", "wnli", "validation"),
    "albert-base-v2-yelp": ("yelp_polarity", None, "test"),
    #
    # xlnet-base-cased
    #
    "xlnet-base-cased-cola": ("glue", "cola", "validation"),
    "xlnet-base-cased-imdb": ("imdb", None, "test"),
    "xlnet-base-cased-mr": ("rotten_tomatoes", None, "test"),
    "xlnet-base-cased-mrpc": ("glue", "mrpc", "validation"),
    "xlnet-base-cased-rte": ("glue", "rte", "validation"),
    "xlnet-base-cased-stsb": (
        "glue",
        "stsb",
        "validation",
        None,
        None,
        None,
        5.0,
    ),
    "xlnet-base-cased-wnli": ("glue", "wnli", "validation"),
}

#
# Models hosted by textattack.
#
TEXTATTACK_DATASET_BY_MODEL = {
    #
    # LSTMs
    #
    "lstm-ag-news": ("ag_news", None, "test"),
    "lstm-imdb": ("imdb", None, "test"),
    "lstm-mr": ("rotten_tomatoes", None, "test"),
    "lstm-sst2": ("glue", "sst2", "validation"),
    "lstm-yelp": ("yelp_polarity", None, "test"),
    #
    # CNNs
    #
    "cnn-ag-news": ("ag_news", None, "test"),
    "cnn-imdb": ("imdb", None, "test"),
    "cnn-mr": ("rotten_tomatoes", None, "test"),
    "cnn-sst2": ("glue", "sst2", "validation"),
    "cnn-yelp": ("yelp_polarity", None, "test"),
    #
    # T5 for translation
    #
    "t5-en-de": (
        "textattack.datasets.helpers.TedMultiTranslationDataset",
        "en",
        "de",
    ),
    "t5-en-fr": (
        "textattack.datasets.helpers.TedMultiTranslationDataset",
        "en",
        "fr",
    ),
| "t5-en-ro": ( | |
| "textattack.datasets.helpers.TedMultiTranslationDataset", | |
| "en", | |
| "de", | |
| ), | |
    #
    # T5 for summarization
    #
    "t5-summarization": ("gigaword", None, "test"),
}
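# For illustration: given how ``_create_dataset_from_args`` consumes these
# registries, an entry such as ``"lstm-mr": ("rotten_tomatoes", None, "test")``
# would be loaded roughly as
#
#     textattack.datasets.HuggingFaceDataset("rotten_tomatoes", None, "test", shuffle=False)
#
# whereas the T5 translation entries name a helper class that is instantiated
# directly with the remaining tuple elements, roughly
# ``TedMultiTranslationDataset("en", "de")``.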

@dataclass
class DatasetArgs:
    """Arguments for loading dataset from command line input."""

    dataset_by_model: str = None
    dataset_from_huggingface: str = None
    dataset_from_file: str = None
    dataset_split: str = None
    filter_by_labels: list = None
    @classmethod
    def _add_parser_args(cls, parser):
        """Adds dataset-related arguments to an argparser."""
        dataset_group = parser.add_mutually_exclusive_group()
        dataset_group.add_argument(
            "--dataset-by-model",
            type=str,
            required=False,
            default=None,
            help="Dataset to load depending on the name of the model.",
        )
        dataset_group.add_argument(
            "--dataset-from-huggingface",
            type=str,
            required=False,
            default=None,
            help="Dataset to load from `datasets` repository.",
        )
        dataset_group.add_argument(
            "--dataset-from-file",
            type=str,
            required=False,
            default=None,
            help="Dataset to load from a file.",
        )
        parser.add_argument(
            "--dataset-split",
            type=str,
            required=False,
            default=None,
            help="Split of dataset to use when specifying --dataset-by-model or --dataset-from-huggingface.",
        )
        parser.add_argument(
            "--filter-by-labels",
            nargs="+",
            type=int,
            required=False,
            default=None,
            help="List of labels to keep in the dataset and discard all others.",
        )
        return parser
    @classmethod
    def _create_dataset_from_args(cls, args):
        """Given ``DatasetArgs``, return the specified
        ``textattack.datasets.Dataset`` object."""
        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{cls}`, but got type `{type(args)}`."
        # Automatically detect dataset for huggingface & textattack models.
        # This allows us to use the --model shortcut without specifying a dataset.
        if hasattr(args, "model"):
            args.dataset_by_model = args.model
        if args.dataset_by_model in HUGGINGFACE_DATASET_BY_MODEL:
            args.dataset_from_huggingface = HUGGINGFACE_DATASET_BY_MODEL[
                args.dataset_by_model
            ]
        elif args.dataset_by_model in TEXTATTACK_DATASET_BY_MODEL:
            dataset = TEXTATTACK_DATASET_BY_MODEL[args.dataset_by_model]
            if dataset[0].startswith("textattack"):
                # unsavory way to pass custom dataset classes
                # ex: dataset = ('textattack.datasets.helpers.TedMultiTranslationDataset', 'en', 'de')
                dataset = eval(f"{dataset[0]}")(*dataset[1:])
                return dataset
            else:
                args.dataset_from_huggingface = dataset
        # Get dataset from args.
        if args.dataset_from_file:
            textattack.shared.logger.info(
                f"Loading dataset from file: {args.dataset_from_file}"
            )
            if ARGS_SPLIT_TOKEN in args.dataset_from_file:
                dataset_file, dataset_name = args.dataset_from_file.split(
                    ARGS_SPLIT_TOKEN
                )
            else:
                dataset_file, dataset_name = args.dataset_from_file, "dataset"
            try:
                dataset_module = load_module_from_file(dataset_file)
            except Exception:
                raise ValueError(f"Failed to import file {args.dataset_from_file}")
            try:
                dataset = getattr(dataset_module, dataset_name)
            except AttributeError:
                raise AttributeError(
                    f"Variable ``dataset`` not found in module {args.dataset_from_file}"
                )
        elif args.dataset_from_huggingface:
            dataset_args = args.dataset_from_huggingface
            if isinstance(dataset_args, str):
                if ARGS_SPLIT_TOKEN in dataset_args:
                    dataset_args = dataset_args.split(ARGS_SPLIT_TOKEN)
                else:
                    dataset_args = (dataset_args,)
            if args.dataset_split:
                if len(dataset_args) > 1:
                    dataset_args = (
                        dataset_args[:2] + (args.dataset_split,) + dataset_args[3:]
                    )
                    dataset = textattack.datasets.HuggingFaceDataset(
                        *dataset_args, shuffle=False
                    )
                else:
                    dataset = textattack.datasets.HuggingFaceDataset(
                        *dataset_args, split=args.dataset_split, shuffle=False
                    )
            else:
                dataset = textattack.datasets.HuggingFaceDataset(
                    *dataset_args, shuffle=False
                )
        else:
            raise ValueError("Must supply pretrained model or dataset")

        assert isinstance(
            dataset, textattack.datasets.Dataset
        ), "Loaded `dataset` must be of type `textattack.datasets.Dataset`."

        if args.filter_by_labels:
            dataset.filter_by_labels_(args.filter_by_labels)

        return dataset
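
if __name__ == "__main__":
    # Illustrative usage sketch: wire the DatasetArgs flags into an argparse CLI
    # and load the resulting dataset. This assumes the `textattack` package and
    # the Hugging Face `datasets` library are installed; the dataset name and
    # split below are example values taken from the registries above.
    import argparse

    parser = argparse.ArgumentParser(description="Load a dataset via DatasetArgs.")
    parser = DatasetArgs._add_parser_args(parser)
    cli_args = parser.parse_args(
        ["--dataset-from-huggingface", "rotten_tomatoes", "--dataset-split", "test"]
    )

    # ``_create_dataset_from_args`` asserts that it receives a ``DatasetArgs``
    # instance, so copy the parsed values into one before calling it.
    dataset_args = DatasetArgs(
        dataset_from_huggingface=cli_args.dataset_from_huggingface,
        dataset_split=cli_args.dataset_split,
        filter_by_labels=cli_args.filter_by_labels,
    )
    dataset = DatasetArgs._create_dataset_from_args(dataset_args)
    print(f"Loaded {len(dataset)} examples.")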