Spaces:
Runtime error
Runtime error
| """ | |
| AttackResumeCommand class | |
| =========================== | |
| """ | |
| from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser | |
| import os | |
| import textattack | |
| from textattack import Attacker, CommandLineAttackArgs, DatasetArgs, ModelArgs | |
| from textattack.commands import TextAttackCommand | |
| class AttackResumeCommand(TextAttackCommand): | |
| """The TextAttack attack resume recipe module: | |
| A command line parser to resume a checkpointed attack from user | |
| specifications. | |
| """ | |
| def run(self, args): | |
| checkpoint = self._parse_checkpoint_from_args(args) | |
| assert isinstance(checkpoint.attack_args, CommandLineAttackArgs), ( | |
| f"Expect `attack_args` to be of type `textattack.args.CommandLineAttackArgs`, but got type `{type(checkpoint.attack_args)}`. " | |
| f"If saved `attack_args` is not of type `textattack.args.CommandLineAttackArgs`, cannot resume attack from command line." | |
| ) | |
| # merge/update arguments | |
| checkpoint.attack_args.parallel = args.parallel | |
| if args.checkpoint_dir: | |
| checkpoint.attack_args.checkpoint_dir = args.checkpoint_dir | |
| if args.checkpoint_interval: | |
| checkpoint.attack_args.checkpoint_interval = args.checkpoint_interval | |
| model_wrapper = ModelArgs._create_model_from_args( | |
| checkpoint.attack_args.attack_args | |
| ) | |
| attack = CommandLineAttackArgs._create_attack_from_args( | |
| checkpoint.attack_args, model_wrapper | |
| ) | |
| dataset = DatasetArgs.parse_dataset_from_args(checkpoint.attack_args) | |
| attacker = Attacker.from_checkpoint(attack, dataset, checkpoint) | |
| attacker.attack_dataset() | |
| def _parse_checkpoint_from_args(self, args): | |
| file_name = os.path.basename(args.checkpoint_file) | |
| if file_name.lower() == "latest": | |
| dir_path = os.path.dirname(args.checkpoint_file) | |
| dir_path = dir_path if dir_path else "." | |
| chkpt_file_names = [ | |
| f for f in os.listdir(dir_path) if f.endswith(".ta.chkpt") | |
| ] | |
| assert chkpt_file_names, "AttackCheckpoint directory is empty" | |
| timestamps = [int(f.replace(".ta.chkpt", "")) for f in chkpt_file_names] | |
| latest_file = str(max(timestamps)) + ".ta.chkpt" | |
| checkpoint_path = os.path.join(dir_path, latest_file) | |
| else: | |
| checkpoint_path = args.checkpoint_file | |
| checkpoint = textattack.shared.AttackCheckpoint.load(checkpoint_path) | |
| return checkpoint | |
| def register_subcommand(main_parser: ArgumentParser): | |
| resume_parser = main_parser.add_parser( | |
| "attack-resume", | |
| help="resume a checkpointed attack", | |
| formatter_class=ArgumentDefaultsHelpFormatter, | |
| ) | |
| # Parser for parsing args for resume | |
| resume_parser.add_argument( | |
| "--checkpoint-file", | |
| "-f", | |
| type=str, | |
| required=True, | |
| help='Path of checkpoint file to resume attack from. If "latest" (or "{directory path}/latest") is entered,' | |
| "recover latest checkpoint from either current path or specified directory.", | |
| ) | |
| resume_parser.add_argument( | |
| "--checkpoint-dir", | |
| "-d", | |
| required=False, | |
| type=str, | |
| default=None, | |
| help="The directory to save checkpoint files. If not set, use directory from recovered arguments.", | |
| ) | |
| resume_parser.add_argument( | |
| "--checkpoint-interval", | |
| "-i", | |
| required=False, | |
| type=int, | |
| help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.", | |
| ) | |
| resume_parser.add_argument( | |
| "--parallel", | |
| action="store_true", | |
| default=False, | |
| help="Run attack using multiple GPUs.", | |
| ) | |
| resume_parser.set_defaults(func=AttackResumeCommand()) | |