Spaces:
Runtime error
Runtime error
| """ | |
| ListThingsCommand class | |
| ============================== | |
| """ | |
| from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser | |
| import textattack | |
| from textattack.attack_args import ( | |
| ATTACK_RECIPE_NAMES, | |
| BLACK_BOX_TRANSFORMATION_CLASS_NAMES, | |
| CONSTRAINT_CLASS_NAMES, | |
| GOAL_FUNCTION_CLASS_NAMES, | |
| SEARCH_METHOD_CLASS_NAMES, | |
| WHITE_BOX_TRANSFORMATION_CLASS_NAMES, | |
| ) | |
| from textattack.augment_args import AUGMENTATION_RECIPE_NAMES | |
| from textattack.commands import TextAttackCommand | |
| from textattack.model_args import HUGGINGFACE_MODELS, TEXTATTACK_MODELS | |
def _cb(s):
    """Return ``str(s)`` colored blue via textattack's ``color_text`` (ANSI method)."""
    text = str(s)
    return textattack.shared.utils.color_text(text, color="blue", method="ansi")
class ListThingsCommand(TextAttackCommand):
    """The list module:

    List default things in textattack.
    """

    def _list(self, list_of_things, plain=False):
        """Print a list or dict of things, one entry per line, sorted by name.

        Args:
            list_of_things: Either a ``list`` of names, or a ``dict`` whose
                keys are names and whose values are printed in parentheses
                after each name.
            plain: When true, names are printed as-is; otherwise each name is
                passed through ``_cb`` before printing.

        Raises:
            TypeError: If ``list_of_things`` is neither a list nor a dict.
        """
        if isinstance(list_of_things, list):
            for thing in sorted(list_of_things):
                print(thing if plain else _cb(thing))
        elif isinstance(list_of_things, dict):
            # Iterating a dict yields its keys, so this sorts by name.
            for thing in sorted(list_of_things):
                thing_key = thing if plain else _cb(thing)
                print(f"{thing_key} ({list_of_things[thing]})")
        else:
            raise TypeError(f"Cannot print list of type {type(list_of_things)}")

    @staticmethod
    def things():
        """Build the mapping from feature name to the collection listed for it.

        Returns:
            dict: Keys are the feature names accepted on the command line;
            each value is the list or dict handed to ``_list`` for that
            feature.
        """
        return {
            # Models from both registries are merged into one flat name list.
            "models": list(HUGGINGFACE_MODELS.keys())
            + list(TEXTATTACK_MODELS.keys()),
            "search-methods": SEARCH_METHOD_CLASS_NAMES,
            # Black-box and white-box transformations are shown together; on
            # a name clash the white-box entry wins (it is unpacked last).
            "transformations": {
                **BLACK_BOX_TRANSFORMATION_CLASS_NAMES,
                **WHITE_BOX_TRANSFORMATION_CLASS_NAMES,
            },
            "constraints": CONSTRAINT_CLASS_NAMES,
            "goal-functions": GOAL_FUNCTION_CLASS_NAMES,
            "attack-recipes": ATTACK_RECIPE_NAMES,
            "augmentation-recipes": AUGMENTATION_RECIPE_NAMES,
        }

    def run(self, args):
        """Print every entry of the feature selected by ``args.feature``.

        Args:
            args: Parsed arguments; ``args.feature`` selects the collection
                and ``args.plain`` disables colored output.

        Raises:
            ValueError: If ``args.feature`` is not a key of ``things()``.
        """
        try:
            list_of_things = ListThingsCommand.things()[args.feature]
        except KeyError as err:
            # The parser defines ``feature`` (there is no ``thing``
            # attribute), so report the key that was actually looked up.
            raise ValueError(f"Unknown list key {args.feature}") from err
        self._list(list_of_things, plain=args.plain)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        """Register the ``list`` subcommand and its arguments.

        Args:
            main_parser: Object whose ``add_parser`` method is used to create
                the ``list`` subparser.
        """
        parser = main_parser.add_parser(
            "list",
            help="list features in TextAttack",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "feature", help="the feature to list", choices=ListThingsCommand.things()
        )
        parser.add_argument(
            "--plain",
            help="print output without color",
            default=False,
            action="store_true",
        )
        # Store a command instance as ``func`` on the parsed arguments.
        parser.set_defaults(func=ListThingsCommand())