| """.. _goal_function: | |
| GoalFunction Class | |
| =========================================================== | |
| """ | |
| from abc import ABC, abstractmethod | |
| import lru | |
| import numpy as np | |
| import torch | |
| from textattack.goal_function_results.goal_function_result import ( | |
| GoalFunctionResultStatus, | |
| ) | |
| from textattack.shared import validators | |
| from textattack.shared.utils import ReprMixin | |


class GoalFunction(ReprMixin, ABC):
    """Evaluates how well a perturbed ``attacked_text`` object is achieving a
    specified goal.

    Args:
        model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
            The victim model to attack.
        maximizable (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether the goal function is maximizable, as opposed to a boolean
            result of success or failure.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to cache model outputs for previously seen inputs.
        query_budget (:obj:`float`, `optional`, defaults to :obj:`float("inf")`):
            The maximum number of model queries allowed.
        model_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
            The batch size to use when querying the model.
        model_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**20`):
            The maximum number of items to keep in the model results cache at once.
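
    Example:
        A minimal usage sketch. It assumes a concrete subclass such as
        :class:`~textattack.goal_functions.UntargetedClassification` and that
        ``model_wrapper``, ``attacked_text``, ``ground_truth_output``, and
        ``perturbed_texts`` already exist in the caller's scope::

            >>> from textattack.goal_functions import UntargetedClassification
            >>> goal_function = UntargetedClassification(model_wrapper, query_budget=500)
            >>> initial_result, _ = goal_function.init_attack_example(
            ...     attacked_text, ground_truth_output
            ... )
            >>> results, search_over = goal_function.get_results(perturbed_texts)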
| """ | |

    def __init__(
        self,
        model_wrapper,
        maximizable=False,
        use_cache=True,
        query_budget=float("inf"),
        model_batch_size=32,
        model_cache_size=2**20,
    ):
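        # Check that this goal function supports the type of victim model
        # wrapped by `model_wrapper` before doing anything else.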
        validators.validate_model_goal_function_compatibility(
            self.__class__, model_wrapper.model.__class__
        )
        self.model = model_wrapper
        self.maximizable = maximizable
        self.use_cache = use_cache
        self.query_budget = query_budget
        self.batch_size = model_batch_size
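        # Optionally cache model outputs keyed by AttackedText so repeated
        # queries for the same perturbed text never re-run the model.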
        if self.use_cache:
            self._call_model_cache = lru.LRU(model_cache_size)
        else:
            self._call_model_cache = None

    def clear_cache(self):
        if self.use_cache:
            self._call_model_cache.clear()

    def init_attack_example(self, attacked_text, ground_truth_output):
        """Called before attacking ``attacked_text`` to 'reset' the goal
        function and set properties for this example."""
        self.initial_attacked_text = attacked_text
        self.ground_truth_output = ground_truth_output
        self.num_queries = 0
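        # Score the unperturbed text first; with check_skip=True, examples on
        # which the goal is already met (e.g., an input the model already
        # misclassifies, for a classification goal) come back as SKIPPED.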
        return self.get_result(attacked_text, check_skip=True)

    def get_output(self, attacked_text):
        """Returns output for display based on the result of calling the
        model."""
        return self._get_displayed_output(self._call_model([attacked_text])[0])

    def get_result(self, attacked_text, **kwargs):
        """A helper method that queries ``self.get_results`` with a single
        ``AttackedText`` object."""
        results, search_over = self.get_results([attacked_text], **kwargs)
        result = results[0] if len(results) else None
        return result, search_over

    def get_results(self, attacked_text_list, check_skip=False):
        """For each attacked_text object in attacked_text_list, returns a
        result consisting of whether or not the goal has been achieved, the
        output for display purposes, and a score.

        Additionally returns whether the search is over due to the query
        budget.
        """
        results = []
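        # Respect the query budget: only score as many candidates as there
        # are queries remaining.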
        if self.query_budget < float("inf"):
            queries_left = self.query_budget - self.num_queries
            attacked_text_list = attacked_text_list[:queries_left]
        self.num_queries += len(attacked_text_list)
        model_outputs = self._call_model(attacked_text_list)
        for attacked_text, raw_output in zip(attacked_text_list, model_outputs):
            displayed_output = self._get_displayed_output(raw_output)
            goal_status = self._get_goal_status(
                raw_output, attacked_text, check_skip=check_skip
            )
            goal_function_score = self._get_score(raw_output, attacked_text)
            results.append(
                self._goal_function_result_type()(
                    attacked_text,
                    raw_output,
                    displayed_output,
                    goal_status,
                    goal_function_score,
                    self.num_queries,
                    self.ground_truth_output,
                )
            )
        return results, self.num_queries == self.query_budget

    def _get_goal_status(self, model_output, attacked_text, check_skip=False):
        should_skip = check_skip and self._should_skip(model_output, attacked_text)
        if should_skip:
            return GoalFunctionResultStatus.SKIPPED
        if self.maximizable:
            return GoalFunctionResultStatus.MAXIMIZING
        if self._is_goal_complete(model_output, attacked_text):
            return GoalFunctionResultStatus.SUCCEEDED
        return GoalFunctionResultStatus.SEARCHING

    def _is_goal_complete(self, model_output, attacked_text):
        raise NotImplementedError()

    def _should_skip(self, model_output, attacked_text):
        return self._is_goal_complete(model_output, attacked_text)

    def _get_score(self, model_output, attacked_text):
        raise NotImplementedError()

    def _get_displayed_output(self, raw_output):
        return raw_output

    def _goal_function_result_type(self):
        """Returns the class of this goal function's results."""
        raise NotImplementedError()

    def _process_model_outputs(self, inputs, outputs):
        """Processes and validates a list of model outputs.

        This is a task-dependent operation. For example, classification
        outputs need to make sure they have a softmax applied.
        """
        raise NotImplementedError()

    def _call_model_uncached(self, attacked_text_list):
        """Queries model and returns outputs for a list of AttackedText
        objects."""
        if not len(attacked_text_list):
            return []

        inputs = [at.tokenizer_input for at in attacked_text_list]
        outputs = []
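        # Query the model in batches of `self.batch_size`, normalizing each
        # batch's predictions into either a flat list of outputs or per-batch
        # tensors that are concatenated below.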
        i = 0
        while i < len(inputs):
            batch = inputs[i : i + self.batch_size]
            batch_preds = self.model(batch)

            # Some seq-to-seq models will return a single string as a prediction
            # for a single-string list. Wrap these in a list.
            if isinstance(batch_preds, str):
                batch_preds = [batch_preds]

            # Get PyTorch tensors off of other devices.
            if isinstance(batch_preds, torch.Tensor):
                batch_preds = batch_preds.cpu()

            if isinstance(batch_preds, list):
                outputs.extend(batch_preds)
            elif isinstance(batch_preds, np.ndarray):
                outputs.append(torch.tensor(batch_preds))
            else:
                outputs.append(batch_preds)
            i += self.batch_size
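        # If the model produced tensors, merge the per-batch tensors into a
        # single tensor covering every input.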
        if isinstance(outputs[0], torch.Tensor):
            outputs = torch.cat(outputs, dim=0)

        assert len(inputs) == len(
            outputs
        ), f"Got {len(outputs)} outputs for {len(inputs)} inputs"

        return self._process_model_outputs(attacked_text_list, outputs)

    def _call_model(self, attacked_text_list):
        """Gets predictions for a list of ``AttackedText`` objects.

        Gets prediction from cache if possible. If prediction is not in
        the cache, queries model and stores prediction in cache.
        """
        if not self.use_cache:
            return self._call_model_uncached(attacked_text_list)
        else:
            uncached_list = []
            for text in attacked_text_list:
                if text in self._call_model_cache:
                    # Re-write value in cache. This moves the key to the top of the
                    # LRU cache and prevents the unlikely event that the text
                    # is overwritten when we store the inputs from `uncached_list`.
                    self._call_model_cache[text] = self._call_model_cache[text]
                else:
                    uncached_list.append(text)
            outputs = self._call_model_uncached(uncached_list)
            for text, output in zip(uncached_list, outputs):
                self._call_model_cache[text] = output
            all_outputs = [self._call_model_cache[text] for text in attacked_text_list]
            return all_outputs

    def extra_repr_keys(self):
        attrs = []
        if self.query_budget < float("inf"):
            attrs.append("query_budget")
        if self.maximizable:
            attrs.append("maximizable")
        return attrs

    def __getstate__(self):
        state = self.__dict__.copy()
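        # The lru.LRU cache is not serialized directly; store only its
        # capacity so an equivalent (empty) cache can be rebuilt on unpickling.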
        if self.use_cache:
            state["_call_model_cache"] = self._call_model_cache.get_size()
        return state

    def __setstate__(self, state):
        self.__dict__ = state
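        # Recreate an empty cache with the capacity saved by __getstate__.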
        if self.use_cache:
            self._call_model_cache = lru.LRU(state["_call_model_cache"])