Spaces:
Runtime error
Runtime error
| """ | |
| Attack Class | |
| ============ | |
| """ | |
| from collections import OrderedDict | |
| from typing import List, Union | |
| import lru | |
| import torch | |
| import textattack | |
| from textattack.attack_results import ( | |
| FailedAttackResult, | |
| MaximizedAttackResult, | |
| SkippedAttackResult, | |
| SuccessfulAttackResult, | |
| ) | |
| from textattack.constraints import Constraint, PreTransformationConstraint | |
| from textattack.goal_function_results import GoalFunctionResultStatus | |
| from textattack.goal_functions import GoalFunction | |
| from textattack.models.wrappers import ModelWrapper | |
| from textattack.search_methods import SearchMethod | |
| from textattack.shared import AttackedText, utils | |
| from textattack.transformations import CompositeTransformation, Transformation | |
class Attack:
    """An attack generates adversarial examples on text.

    An attack is comprised of a goal function, constraints, transformation, and a search method. Use :meth:`attack` method to attack one sample at a time.

    Args:
        goal_function (:class:`~textattack.goal_functions.GoalFunction`):
            A function for determining how well a perturbation is doing at achieving the attack's goal.
        constraints (list of :class:`~textattack.constraints.Constraint` or :class:`~textattack.constraints.PreTransformationConstraint`):
            A list of constraints to add to the attack, defining which perturbations are valid.
        transformation (:class:`~textattack.transformations.Transformation`):
            The transformation applied at each step of the attack.
        search_method (:class:`~textattack.search_methods.SearchMethod`):
            The method for exploring the search space of possible perturbations
        transformation_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
            The number of items to keep in the transformations cache
        constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
            The number of items to keep in the constraints cache

    Example::

        >>> import textattack
        >>> import transformers

        >>> # Load model, tokenizer, and model_wrapper
        >>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

        >>> # Construct our four components for `Attack`
        >>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
        >>> from textattack.constraints.semantics import WordEmbeddingDistance
        >>> from textattack.search_methods import GreedyWordSwapWIR
        >>> from textattack.transformations import WordSwapEmbedding

        >>> goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
        >>> constraints = [
        ...     RepeatModification(),
        ...     StopwordModification(),
        ...     WordEmbeddingDistance(min_cos_sim=0.9),
        ... ]
        >>> transformation = WordSwapEmbedding(max_candidates=50)
        >>> search_method = GreedyWordSwapWIR(wir_method="delete")

        >>> # Construct the actual attack
        >>> attack = Attack(goal_function, constraints, transformation, search_method)

        >>> input_text = "I really enjoyed the new movie that came out last month."
        >>> label = 1 #Positive
        >>> attack_result = attack.attack(input_text, label)
    """

    def __init__(
        self,
        goal_function: GoalFunction,
        constraints: List[Union[Constraint, PreTransformationConstraint]],
        transformation: Transformation,
        search_method: SearchMethod,
        transformation_cache_size=2**15,
        constraint_cache_size=2**15,
    ):
        """Initialize an attack object.

        Attacks can be run multiple times.

        Raises:
            AssertionError: If any component is not of the expected type.
            ValueError: If ``search_method`` reports that it is incompatible
                with ``transformation``.
        """
        assert isinstance(
            goal_function, GoalFunction
        ), f"`goal_function` must be of type `textattack.goal_functions.GoalFunction`, but got type `{type(goal_function)}`."
        assert isinstance(
            constraints, list
        ), "`constraints` must be a list of `textattack.constraints.Constraint` or `textattack.constraints.PreTransformationConstraint`."
        for c in constraints:
            assert isinstance(
                c, (Constraint, PreTransformationConstraint)
            ), "`constraints` must be a list of `textattack.constraints.Constraint` or `textattack.constraints.PreTransformationConstraint`."
        assert isinstance(
            transformation, Transformation
        ), f"`transformation` must be of type `textattack.transformations.Transformation`, but got type `{type(transformation)}`."
        assert isinstance(
            search_method, SearchMethod
        ), f"`search_method` must be of type `textattack.search_methods.SearchMethod`, but got type `{type(search_method)}`."

        self.goal_function = goal_function
        self.search_method = search_method
        self.transformation = transformation
        # A transformation without an `is_black_box` attribute is treated as black-box.
        self.is_black_box = (
            getattr(transformation, "is_black_box", True) and search_method.is_black_box
        )

        if not self.search_method.check_transformation_compatibility(
            self.transformation
        ):
            raise ValueError(
                f"SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
            )

        # Split the constraints by kind: pre-transformation constraints are handed
        # to the transformation, the rest are applied to its output.
        self.constraints = []
        self.pre_transformation_constraints = []
        for constraint in constraints:
            if isinstance(constraint, PreTransformationConstraint):
                self.pre_transformation_constraints.append(constraint)
            else:
                self.constraints.append(constraint)

        # Check if we can use transformation cache for our transformation.
        # Caching is only valid when the transformation (and, for a composite,
        # every sub-transformation) is deterministic.
        if not self.transformation.deterministic:
            self.use_transformation_cache = False
        elif isinstance(self.transformation, CompositeTransformation):
            self.use_transformation_cache = all(
                t.deterministic for t in self.transformation.transformations
            )
        else:
            self.use_transformation_cache = True

        self.transformation_cache_size = transformation_cache_size
        self.transformation_cache = lru.LRU(transformation_cache_size)
        self.constraint_cache_size = constraint_cache_size
        self.constraints_cache = lru.LRU(constraint_cache_size)

        # Give search method access to functions for getting transformations and evaluating them
        self.search_method.get_transformations = self.get_transformations
        # Give search method access to self.goal_function for model query count, etc.
        self.search_method.goal_function = self.goal_function
        # The search method only needs access to the first argument. The second is only used
        # by the attack class when checking whether to skip the sample
        self.search_method.get_goal_results = self.goal_function.get_results
        # Give search method access to get indices which need to be ordered / searched
        self.search_method.get_indices_to_order = self.get_indices_to_order
        self.search_method.filter_transformations = self.filter_transformations

    def clear_cache(self, recursive=True):
        """Empty the constraint cache and, if in use, the transformation cache.

        Args:
            recursive: When true, also clear the goal function's cache and the
                cache of every constraint in ``self.constraints`` that exposes
                a ``clear_cache`` method.
        """
        self.constraints_cache.clear()
        if self.use_transformation_cache:
            self.transformation_cache.clear()
        if recursive:
            self.goal_function.clear_cache()
            for constraint in self.constraints:
                if hasattr(constraint, "clear_cache"):
                    constraint.clear_cache()

    def _apply_to_modules(self, move):
        """Call ``move`` on every `torch.nn.Module` reachable from this attack.

        The walk descends through the attributes of attack components (attack,
        goal function, transformation, search method, constraints, model
        wrapper) and through lists/tuples of transformations and constraints.
        Objects are tracked by ``id`` so reference cycles terminate.

        Args:
            move: Callable invoked with each `torch.nn.Module` found.
        """
        component_types = (
            Attack,
            GoalFunction,
            Transformation,
            SearchMethod,
            Constraint,
            PreTransformationConstraint,
            ModelWrapper,
        )
        sequence_item_types = (Transformation, Constraint, PreTransformationConstraint)
        visited = set()

        def walk(obj):
            visited.add(id(obj))
            if isinstance(obj, torch.nn.Module):
                move(obj)
            elif isinstance(obj, component_types):
                for child in obj.__dict__.values():
                    if id(child) not in visited:
                        walk(child)
            elif isinstance(obj, (list, tuple)):
                for item in obj:
                    if id(item) not in visited and isinstance(
                        item, sequence_item_types
                    ):
                        walk(item)

        walk(self)

    def cpu_(self):
        """Move any `torch.nn.Module` models that are part of Attack to CPU."""
        self._apply_to_modules(lambda module: module.cpu())

    def cuda_(self):
        """Move any `torch.nn.Module` models that are part of Attack to GPU."""
        # The target device is looked up for each module, at the time it is moved.
        self._apply_to_modules(
            lambda module: module.to(textattack.shared.utils.device)
        )

    def get_indices_to_order(self, current_text, **kwargs):
        """Applies ``pre_transformation_constraints`` to ``text`` to get all
        the indices that can be used to search and order.

        Args:
            current_text: The current ``AttackedText`` for which we need to find indices are eligible to be ordered.
        Returns:
            The length and the filtered list of indices which search methods can use to search/order.
        """
        indices_to_order = self.transformation(
            current_text,
            pre_transformation_constraints=self.pre_transformation_constraints,
            return_indices=True,
            **kwargs,
        )

        len_text = len(indices_to_order)

        # Convert indices_to_order to list for easier shuffling later
        return len_text, list(indices_to_order)

    def _get_transformations_uncached(self, current_text, original_text=None, **kwargs):
        """Applies ``self.transformation`` to ``text``, passing along the
        pre-transformation constraints. The result is NOT filtered through
        ``self.constraints``; callers do that via ``filter_transformations``.

        Args:
            current_text: The current ``AttackedText`` on which to perform the transformations.
            original_text: The original ``AttackedText`` from which the attack started.
                Accepted for signature symmetry; not used here.
        Returns:
            The transformed texts produced by ``self.transformation``.
        """
        transformed_texts = self.transformation(
            current_text,
            pre_transformation_constraints=self.pre_transformation_constraints,
            **kwargs,
        )

        return transformed_texts

    def get_transformations(self, current_text, original_text=None, **kwargs):
        """Applies ``self.transformation`` to ``text``, then filters the list
        of possible transformations through the applicable constraints.

        Args:
            current_text: The current ``AttackedText`` on which to perform the transformations.
            original_text: The original ``AttackedText`` from which the attack started.
        Returns:
            A filtered list of transformations where each transformation matches the constraints
        Raises:
            RuntimeError: If the attack has no transformation.
        """
        if not self.transformation:
            raise RuntimeError(
                "Cannot call `get_transformations` without a transformation."
            )

        if self.use_transformation_cache:
            cache_key = tuple([current_text] + sorted(kwargs.items()))
            # Unhashable keys (e.g. list-valued kwargs) simply bypass the cache.
            cacheable = utils.hashable(cache_key)
            if cacheable and cache_key in self.transformation_cache:
                cached = self.transformation_cache[cache_key]
                # promote transformed_text to the top of the LRU cache
                self.transformation_cache[cache_key] = cached
                transformed_texts = list(cached)
            else:
                transformed_texts = self._get_transformations_uncached(
                    current_text, original_text, **kwargs
                )
                if cacheable:
                    self.transformation_cache[cache_key] = tuple(transformed_texts)
        else:
            transformed_texts = self._get_transformations_uncached(
                current_text, original_text, **kwargs
            )

        return self.filter_transformations(
            transformed_texts, current_text, original_text
        )

    def _filter_transformations_uncached(
        self, transformed_texts, current_text, original_text=None
    ):
        """Filters a list of potential transformed texts based on
        ``self.constraints`` and records each verdict in the constraints cache.

        Args:
            transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
            current_text: The current ``AttackedText`` on which the transformation was applied.
            original_text: The original ``AttackedText`` from which the attack started.
        Returns:
            The candidates that satisfy every constraint.
        Raises:
            ValueError: If a constraint compares against the original text but
                ``original_text`` was not supplied.
        """
        filtered_texts = transformed_texts[:]
        for C in self.constraints:
            # Nothing left to filter; skip the remaining constraints.
            if len(filtered_texts) == 0:
                break
            if C.compare_against_original:
                if original_text is None:
                    raise ValueError(
                        f"Missing `original_text` argument when constraint {type(C)} is set to compare against `original_text`"
                    )

                filtered_texts = C.call_many(filtered_texts, original_text)
            else:
                filtered_texts = C.call_many(filtered_texts, current_text)

        # Default to false for all original transformations.
        for original_transformed_text in transformed_texts:
            self.constraints_cache[(current_text, original_transformed_text)] = False
        # Set unfiltered transformations to True in the cache.
        for filtered_text in filtered_texts:
            self.constraints_cache[(current_text, filtered_text)] = True

        return filtered_texts

    def filter_transformations(
        self, transformed_texts, current_text, original_text=None
    ):
        """Filters a list of potential transformed texts based on
        ``self.constraints`` Utilizes an LRU cache to attempt to avoid
        recomputing common transformations.

        Args:
            transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
            current_text: The current ``AttackedText`` on which the transformation was applied.
            original_text: The original ``AttackedText`` from which the attack started.
        Returns:
            The accepted candidates, sorted by their ``text``.
        """
        # Remove any occurrences of current_text in transformed_texts
        transformed_texts = [
            t for t in transformed_texts if t.text != current_text.text
        ]

        # Split candidates into those with a cached verdict and those without.
        uncached_texts = []
        filtered_texts = []
        for transformed_text in transformed_texts:
            cache_key = (current_text, transformed_text)
            if cache_key not in self.constraints_cache:
                uncached_texts.append(transformed_text)
            else:
                verdict = self.constraints_cache[cache_key]
                # promote transformed_text to the top of the LRU cache
                self.constraints_cache[cache_key] = verdict
                if verdict:
                    filtered_texts.append(transformed_text)

        filtered_texts += self._filter_transformations_uncached(
            uncached_texts, current_text, original_text=original_text
        )

        # Sort transformations to ensure order is preserved between runs
        filtered_texts.sort(key=lambda t: t.text)
        return filtered_texts

    def _attack(self, initial_result):
        """Calls the ``SearchMethod`` to perturb the ``AttackedText`` stored in
        ``initial_result``.

        Args:
            initial_result: The initial ``GoalFunctionResult`` from which to perturb.

        Returns:
            A ``SuccessfulAttackResult``, ``FailedAttackResult``,
                or ``MaximizedAttackResult``.
        Raises:
            ValueError: If the search ends in an unrecognized goal status.
        """
        final_result = self.search_method(initial_result)
        # Caches are cleared after every search, before the result is built.
        self.clear_cache()

        if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
            result = SuccessfulAttackResult(
                initial_result,
                final_result,
            )
        elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
            result = FailedAttackResult(
                initial_result,
                final_result,
            )
        elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
            result = MaximizedAttackResult(
                initial_result,
                final_result,
            )
        else:
            raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
        return result

    def attack(self, example, ground_truth_output):
        """Attack a single example.

        Args:
            example (:obj:`str`, :obj:`OrderedDict[str, str]` or :class:`~textattack.shared.AttackedText`):
                Example to attack. It can be a single string or an `OrderedDict` where
                keys represent the input fields (e.g. "premise", "hypothesis") and the values are the actual input texts.
                Also accepts :class:`~textattack.shared.AttackedText` that wraps around the input.
            ground_truth_output(:obj:`int`, :obj:`float` or :obj:`str`):
                Ground truth output of `example`.
                For classification tasks, it should be an integer representing the ground truth label.
                For regression tasks (e.g. STS), it should be the target value.
                For seq2seq tasks (e.g. translation), it should be the target string.
        Returns:
            :class:`~textattack.attack_results.AttackResult` that represents the result of the attack.
        """
        assert isinstance(
            example, (str, OrderedDict, AttackedText)
        ), "`example` must either be `str`, `collections.OrderedDict`, `textattack.shared.AttackedText`."
        if isinstance(example, (str, OrderedDict)):
            example = AttackedText(example)

        # `float` is accepted so regression targets, as documented above, are not rejected.
        assert isinstance(
            ground_truth_output, (int, float, str)
        ), "`ground_truth_output` must either be `str`, `int` or `float`."
        goal_function_result, _ = self.goal_function.init_attack_example(
            example, ground_truth_output
        )
        if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
            return SkippedAttackResult(goal_function_result)
        else:
            result = self._attack(goal_function_result)
            return result

    def __repr__(self):
        """Prints attack parameters in a human-readable string.

        Inspired by the readability of printing PyTorch nn.Modules:
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
        """
        main_str = "Attack" + "("
        lines = []

        lines.append(utils.add_indent(f"(search_method): {self.search_method}", 2))
        # self.goal_function
        lines.append(utils.add_indent(f"(goal_function): {self.goal_function}", 2))
        # self.transformation
        lines.append(utils.add_indent(f"(transformation): {self.transformation}", 2))
        # self.constraints
        constraints_lines = []
        constraints = self.constraints + self.pre_transformation_constraints
        if constraints:
            for i, constraint in enumerate(constraints):
                constraints_lines.append(utils.add_indent(f"({i}): {constraint}", 2))
            constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2)
        else:
            constraints_str = "None"
        lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2))
        # self.is_black_box
        lines.append(utils.add_indent(f"(is_black_box): {self.is_black_box}", 2))
        main_str += "\n " + "\n ".join(lines) + "\n"
        main_str += ")"
        return main_str

    def __getstate__(self):
        """Return the instance state with both caches replaced by ``None``.

        The live instance keeps its caches; only the returned copy drops them.
        """
        state = self.__dict__.copy()
        state["transformation_cache"] = None
        state["constraints_cache"] = None
        return state

    def __setstate__(self, state):
        """Restore instance state and rebuild empty caches of the stored sizes."""
        self.__dict__ = state
        self.transformation_cache = lru.LRU(self.transformation_cache_size)
        self.constraints_cache = lru.LRU(self.constraint_cache_size)

    __str__ = __repr__