| """ | |
| Misc Checkpoints | |
| =================== | |
| The ``AttackCheckpoint`` class saves in-progress attacks and loads saved attacks from disk. | |
| """ | |

import copy
import datetime
import os
import pickle
import time

import textattack
from textattack.attack_results import (
    FailedAttackResult,
    MaximizedAttackResult,
    SkippedAttackResult,
    SuccessfulAttackResult,
)
from textattack.shared import logger, utils


# TODO: Consider still keeping the old `Checkpoint` class and allowing older checkpoints to be loaded into newer versions of TextAttack.
class AttackCheckpoint:
    """An object that stores the information necessary for saving and loading
    checkpoints.

    Args:
        attack_args (textattack.AttackArgs): Arguments of the original attack.
        attack_log_manager (textattack.loggers.AttackLogManager): Object for storing attack results.
        worklist (deque[int]): Examples that will be attacked, represented by their indices within the dataset.
        worklist_candidates (deque[int]): Other available examples we can attack. Used to get the next dataset element when `attack_n=True`.
        chkpt_time (float): Epoch time at which the checkpoint was made.
    """

    def __init__(
        self,
        attack_args,
        attack_log_manager,
        worklist,
        worklist_candidates,
        chkpt_time=None,
    ):
        assert isinstance(
            attack_args, textattack.AttackArgs
        ), "`attack_args` must be of type `textattack.AttackArgs`."
        assert isinstance(
            attack_log_manager, textattack.loggers.AttackLogManager
        ), "`attack_log_manager` must be of type `textattack.loggers.AttackLogManager`."

        self.attack_args = copy.deepcopy(attack_args)
        self.attack_log_manager = attack_log_manager
        self.worklist = worklist
        self.worklist_candidates = worklist_candidates
        if chkpt_time:
            self.time = chkpt_time
        else:
            self.time = time.time()

        self._verify()

    def __repr__(self):
        main_str = "AttackCheckpoint("
        lines = []
        lines.append(utils.add_indent(f"(Time): {self.datetime}", 2))

        args_lines = []
        recipe_set = (
            True
            if "recipe" in self.attack_args.__dict__
            and self.attack_args.__dict__["recipe"]
            else False
        )
        mutually_exclusive_args = ["search", "transformation", "constraints", "recipe"]
        if recipe_set:
            args_lines.append(
                utils.add_indent(f'(recipe): {self.attack_args.__dict__["recipe"]}', 2)
            )
        else:
            args_lines.append(
                utils.add_indent(f'(search): {self.attack_args.__dict__["search"]}', 2)
            )
            args_lines.append(
                utils.add_indent(
                    f'(transformation): {self.attack_args.__dict__["transformation"]}',
                    2,
                )
            )
            args_lines.append(
                utils.add_indent(
                    f'(constraints): {self.attack_args.__dict__["constraints"]}', 2
                )
            )

        for key in self.attack_args.__dict__:
            if key not in mutually_exclusive_args:
                args_lines.append(
                    utils.add_indent(f"({key}): {self.attack_args.__dict__[key]}", 2)
                )

        args_str = utils.add_indent("\n" + "\n".join(args_lines), 2)
        lines.append(utils.add_indent(f"(attack_args): {args_str}", 2))

        attack_logger_lines = []
        attack_logger_lines.append(
            utils.add_indent(
                f"(Total number of examples to attack): {self.attack_args.num_examples}",
                2,
            )
        )
        attack_logger_lines.append(
            utils.add_indent(f"(Number of attacks performed): {self.results_count}", 2)
        )
        attack_logger_lines.append(
            utils.add_indent(
                f"(Number of remaining attacks): {self.num_remaining_attacks}", 2
            )
        )

        breakdown_lines = []
        breakdown_lines.append(
            utils.add_indent(
                f"(Number of successful attacks): {self.num_successful_attacks}", 2
            )
        )
        breakdown_lines.append(
            utils.add_indent(
                f"(Number of failed attacks): {self.num_failed_attacks}", 2
            )
        )
        breakdown_lines.append(
            utils.add_indent(
                f"(Number of maximized attacks): {self.num_maximized_attacks}", 2
            )
        )
        breakdown_lines.append(
            utils.add_indent(
                f"(Number of skipped attacks): {self.num_skipped_attacks}", 2
            )
        )
        breakdown_str = utils.add_indent("\n" + "\n".join(breakdown_lines), 2)
        attack_logger_lines.append(
            utils.add_indent(f"(Latest result breakdown): {breakdown_str}", 2)
        )

        attack_logger_str = utils.add_indent("\n" + "\n".join(attack_logger_lines), 2)
        lines.append(
            utils.add_indent(f"(Previous attack summary): {attack_logger_str}", 2)
        )

        main_str += "\n " + "\n ".join(lines) + "\n"
        main_str += ")"
        return main_str

    __str__ = __repr__

    @property
    def results_count(self):
        """Return the number of attacks made so far."""
        return len(self.attack_log_manager.results)

    @property
    def num_skipped_attacks(self):
        return sum(
            isinstance(r, SkippedAttackResult) for r in self.attack_log_manager.results
        )

    @property
    def num_failed_attacks(self):
        return sum(
            isinstance(r, FailedAttackResult) for r in self.attack_log_manager.results
        )

    @property
    def num_successful_attacks(self):
        return sum(
            isinstance(r, SuccessfulAttackResult)
            for r in self.attack_log_manager.results
        )

    @property
    def num_maximized_attacks(self):
        return sum(
            isinstance(r, MaximizedAttackResult)
            for r in self.attack_log_manager.results
        )

    @property
    def num_remaining_attacks(self):
        if self.attack_args.attack_n:
            non_skipped_attacks = self.num_successful_attacks + self.num_failed_attacks
            count = self.attack_args.num_examples - non_skipped_attacks
        else:
            count = self.attack_args.num_examples - self.results_count
        return count

    @property
    def dataset_offset(self):
        """Calculate the offset into the dataset to start from."""
        # Original offset + number of results processed so far
        return self.attack_args.num_examples_offset + self.results_count

    @property
    def datetime(self):
        return datetime.datetime.fromtimestamp(self.time).strftime("%Y-%m-%d %H:%M:%S")

    def save(self, quiet=False):
        """Pickle this checkpoint to ``<checkpoint_dir>/<epoch_ms>.ta.chkpt``."""
        file_name = "{}.ta.chkpt".format(int(self.time * 1000))
        if not os.path.exists(self.attack_args.checkpoint_dir):
            os.makedirs(self.attack_args.checkpoint_dir)
        path = os.path.join(self.attack_args.checkpoint_dir, file_name)
        if not quiet:
            print("\n\n" + "=" * 125)
            logger.info(
                'Saving checkpoint under "{}" at {} after {} attacks.'.format(
                    path, self.datetime, self.results_count
                )
            )
            print("=" * 125 + "\n")
        with open(path, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, path):
        """Load a previously saved ``AttackCheckpoint`` from ``path``."""
        with open(path, "rb") as f:
            checkpoint = pickle.load(f)
        assert isinstance(checkpoint, cls)
        return checkpoint

    def _verify(self):
        """Check that the checkpoint has no duplicates and is consistent."""
        assert self.num_remaining_attacks == len(
            self.worklist
        ), "Recorded number of remaining attacks and size of worklist are different."

        results_set = {
            result.original_text for result in self.attack_log_manager.results
        }
        assert (
            len(results_set) == self.results_count
        ), "Duplicate `AttackResults` found."