Spaces:
Runtime error
Runtime error
| """.. _attacked_text: | |
| Attacked Text Class | |
| ===================== | |
| A helper class that represents a string that can be attacked. | |
| """ | |
| from collections import OrderedDict | |
| import math | |
| import flair | |
| from flair.data import Sentence | |
| import numpy as np | |
| import torch | |
| import textattack | |
| from .utils import device, words_from_text | |
| flair.device = device | |
class AttackedText:
    """A helper class that represents a string that can be attacked.

    Models that take multiple sentences as input separate them by ``SPLIT_TOKEN``.
    Attacks "see" the entire input, joined into one string, without the split token.

    ``AttackedText`` instances that were perturbed from other ``AttackedText``
    objects contain a pointer to the previous text
    (``attack_attrs["previous_attacked_text"]``), so that the full chain of
    perturbations might be reconstructed by using this key to form a linked
    list.

    Args:
        text (string): The string that this AttackedText represents
        attack_attrs (dict): Dictionary of various attributes stored
            during the course of an attack.
    """

    SPLIT_TOKEN = "<SPLIT>"

    def __init__(self, text_input, attack_attrs=None):
        # Accept either a raw string (stored under the single column "text")
        # or an already-keyed OrderedDict of column name -> text.
        if isinstance(text_input, str):
            self._text_input = OrderedDict([("text", text_input)])
        elif isinstance(text_input, OrderedDict):
            self._text_input = text_input
        else:
            raise TypeError(
                f"Invalid text_input type {type(text_input)} (required str or OrderedDict)"
            )
        # Caches for lazily-computed views of the text.
        self._words = None
        self._words_per_input = None
        self._pos_tags = None
        self._ner_tags = None
        # Take a shallow copy so later mutation of a caller-owned dict cannot
        # change this instance's input.
        self._text_input = OrderedDict((k, v) for k, v in self._text_input.items())
        if attack_attrs is None:
            self.attack_attrs = dict()
        elif isinstance(attack_attrs, dict):
            self.attack_attrs = attack_attrs
        else:
            raise TypeError(f"Invalid type for attack_attrs: {type(attack_attrs)}")
        # Maps word indices in the *original* text to indices in this text,
        # so indices can be translated in either direction.
        self.attack_attrs.setdefault("original_index_map", np.arange(self.num_words))
        # Indices of words in *this* text that have been modified so far.
        self.attack_attrs.setdefault("modified_indices", set())
| def __eq__(self, other): | |
| """Compares two text instances to make sure they have the same attack | |
| attributes. | |
| Since some elements stored in ``self.attack_attrs`` may be numpy | |
| arrays, we have to take special care when comparing them. | |
| """ | |
| if not (self.text == other.text): | |
| return False | |
| if len(self.attack_attrs) != len(other.attack_attrs): | |
| return False | |
| for key in self.attack_attrs: | |
| if key not in other.attack_attrs: | |
| return False | |
| elif isinstance(self.attack_attrs[key], np.ndarray): | |
| if not (self.attack_attrs[key].shape == other.attack_attrs[key].shape): | |
| return False | |
| elif not (self.attack_attrs[key] == other.attack_attrs[key]).all(): | |
| return False | |
| else: | |
| if not self.attack_attrs[key] == other.attack_attrs[key]: | |
| return False | |
| return True | |
| def __hash__(self): | |
| return hash(self.text) | |
| def free_memory(self): | |
| """Delete items that take up memory. | |
| Can be called once the AttackedText is only needed to display. | |
| """ | |
| if "previous_attacked_text" in self.attack_attrs: | |
| self.attack_attrs["previous_attacked_text"].free_memory() | |
| self.attack_attrs.pop("previous_attacked_text", None) | |
| self.attack_attrs.pop("last_transformation", None) | |
| for key in self.attack_attrs: | |
| if isinstance(self.attack_attrs[key], torch.Tensor): | |
| self.attack_attrs.pop(key, None) | |
| def text_window_around_index(self, index, window_size): | |
| """The text window of ``window_size`` words centered around | |
| ``index``.""" | |
| length = self.num_words | |
| half_size = (window_size - 1) / 2.0 | |
| if index - half_size < 0: | |
| start = 0 | |
| end = min(window_size - 1, length - 1) | |
| elif index + half_size >= length: | |
| start = max(0, length - window_size) | |
| end = length - 1 | |
| else: | |
| start = index - math.ceil(half_size) | |
| end = index + math.floor(half_size) | |
| text_idx_start = self._text_index_of_word_index(start) | |
| text_idx_end = self._text_index_of_word_index(end) + len(self.words[end]) | |
| return self.text[text_idx_start:text_idx_end] | |
    def pos_of_word_index(self, desired_word_idx):
        """Returns the part-of-speech of the word at index `word_idx`.

        Uses FLAIR part-of-speech tagger. Tagging runs once per instance and
        the tagged ``Sentence`` is cached in ``self._pos_tags``.
        """
        if not self._pos_tags:
            # Tag the full text once; subsequent calls reuse the cache.
            sentence = Sentence(
                self.text,
                use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
            )
            textattack.shared.utils.flair_tag(sentence)
            self._pos_tags = sentence
        flair_word_list, flair_pos_list = textattack.shared.utils.zip_flair_result(
            self._pos_tags
        )
        for word_idx, word in enumerate(self.words):
            assert (
                word in flair_word_list
            ), "word absent in flair returned part-of-speech tags"
            word_idx_in_flair_tags = flair_word_list.index(word)
            if word_idx == desired_word_idx:
                return flair_pos_list[word_idx_in_flair_tags]
            else:
                # Drop everything up to and including this word so repeated
                # words resolve to their next (later) occurrence in the tags.
                flair_word_list = flair_word_list[word_idx_in_flair_tags + 1 :]
                flair_pos_list = flair_pos_list[word_idx_in_flair_tags + 1 :]
        raise ValueError(
            f"Did not find word from index {desired_word_idx} in flair POS tag"
        )
| def ner_of_word_index(self, desired_word_idx, model_name="ner"): | |
| """Returns the ner tag of the word at index `word_idx`. | |
| Uses FLAIR ner tagger. | |
| """ | |
| if not self._ner_tags: | |
| sentence = Sentence( | |
| self.text, | |
| use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(), | |
| ) | |
| textattack.shared.utils.flair_tag(sentence, model_name) | |
| self._ner_tags = sentence | |
| flair_word_list, flair_ner_list = textattack.shared.utils.zip_flair_result( | |
| self._ner_tags, "ner" | |
| ) | |
| for word_idx, word in enumerate(flair_word_list): | |
| word_idx_in_flair_tags = flair_word_list.index(word) | |
| if word_idx == desired_word_idx: | |
| return flair_ner_list[word_idx_in_flair_tags] | |
| else: | |
| flair_word_list = flair_word_list[word_idx_in_flair_tags + 1 :] | |
| flair_ner_list = flair_ner_list[word_idx_in_flair_tags + 1 :] | |
| raise ValueError( | |
| f"Did not find word from index {desired_word_idx} in flair POS tag" | |
| ) | |
| def _text_index_of_word_index(self, i): | |
| """Returns the index of word ``i`` in self.text.""" | |
| pre_words = self.words[: i + 1] | |
| lower_text = self.text.lower() | |
| # Find all words until `i` in string. | |
| look_after_index = 0 | |
| for word in pre_words: | |
| look_after_index = lower_text.find(word.lower(), look_after_index) + len( | |
| word | |
| ) | |
| look_after_index -= len(self.words[i]) | |
| return look_after_index | |
| def text_until_word_index(self, i): | |
| """Returns the text before the beginning of word at index ``i``.""" | |
| look_after_index = self._text_index_of_word_index(i) | |
| return self.text[:look_after_index] | |
| def text_after_word_index(self, i): | |
| """Returns the text after the end of word at index ``i``.""" | |
| # Get index of beginning of word then jump to end of word. | |
| look_after_index = self._text_index_of_word_index(i) + len(self.words[i]) | |
| return self.text[look_after_index:] | |
| def first_word_diff(self, other_attacked_text): | |
| """Returns the first word in self.words that differs from | |
| other_attacked_text. | |
| Useful for word swap strategies. | |
| """ | |
| w1 = self.words | |
| w2 = other_attacked_text.words | |
| for i in range(min(len(w1), len(w2))): | |
| if w1[i] != w2[i]: | |
| return w1[i] | |
| return None | |
| def first_word_diff_index(self, other_attacked_text): | |
| """Returns the index of the first word in self.words that differs from | |
| other_attacked_text. | |
| Useful for word swap strategies. | |
| """ | |
| w1 = self.words | |
| w2 = other_attacked_text.words | |
| for i in range(min(len(w1), len(w2))): | |
| if w1[i] != w2[i]: | |
| return i | |
| return None | |
| def all_words_diff(self, other_attacked_text): | |
| """Returns the set of indices for which this and other_attacked_text | |
| have different words.""" | |
| indices = set() | |
| w1 = self.words | |
| w2 = other_attacked_text.words | |
| for i in range(min(len(w1), len(w2))): | |
| if w1[i] != w2[i]: | |
| indices.add(i) | |
| return indices | |
| def ith_word_diff(self, other_attacked_text, i): | |
| """Returns whether the word at index i differs from | |
| other_attacked_text.""" | |
| w1 = self.words | |
| w2 = other_attacked_text.words | |
| if len(w1) - 1 < i or len(w2) - 1 < i: | |
| return True | |
| return w1[i] != w2[i] | |
| def words_diff_num(self, other_attacked_text): | |
| # using edit distance to calculate words diff num | |
| def generate_tokens(words): | |
| result = {} | |
| idx = 1 | |
| for w in words: | |
| if w not in result: | |
| result[w] = idx | |
| idx += 1 | |
| return result | |
| def words_to_tokens(words, tokens): | |
| result = [] | |
| for w in words: | |
| result.append(tokens[w]) | |
| return result | |
| def edit_distance(w1_t, w2_t): | |
| matrix = [ | |
| [i + j for j in range(len(w2_t) + 1)] for i in range(len(w1_t) + 1) | |
| ] | |
| for i in range(1, len(w1_t) + 1): | |
| for j in range(1, len(w2_t) + 1): | |
| if w1_t[i - 1] == w2_t[j - 1]: | |
| d = 0 | |
| else: | |
| d = 1 | |
| matrix[i][j] = min( | |
| matrix[i - 1][j] + 1, | |
| matrix[i][j - 1] + 1, | |
| matrix[i - 1][j - 1] + d, | |
| ) | |
| return matrix[len(w1_t)][len(w2_t)] | |
| def cal_dif(w1, w2): | |
| tokens = generate_tokens(w1 + w2) | |
| w1_t = words_to_tokens(w1, tokens) | |
| w2_t = words_to_tokens(w2, tokens) | |
| return edit_distance(w1_t, w2_t) | |
| w1 = self.words | |
| w2 = other_attacked_text.words | |
| return cal_dif(w1, w2) | |
| def convert_from_original_idxs(self, idxs): | |
| """Takes indices of words from original string and converts them to | |
| indices of the same words in the current string. | |
| Uses information from | |
| ``self.attack_attrs['original_index_map']``, which maps word | |
| indices from the original to perturbed text. | |
| """ | |
| if len(self.attack_attrs["original_index_map"]) == 0: | |
| return idxs | |
| elif isinstance(idxs, set): | |
| idxs = list(idxs) | |
| elif not isinstance(idxs, [list, np.ndarray]): | |
| raise TypeError( | |
| f"convert_from_original_idxs got invalid idxs type {type(idxs)}" | |
| ) | |
| return [self.attack_attrs["original_index_map"][i] for i in idxs] | |
| def replace_words_at_indices(self, indices, new_words): | |
| """This code returns a new AttackedText object where the word at | |
| ``index`` is replaced with a new word.""" | |
| if len(indices) != len(new_words): | |
| raise ValueError( | |
| f"Cannot replace {len(new_words)} words at {len(indices)} indices." | |
| ) | |
| words = self.words[:] | |
| for i, new_word in zip(indices, new_words): | |
| if not isinstance(new_word, str): | |
| raise TypeError( | |
| f"replace_words_at_indices requires ``str`` words, got {type(new_word)}" | |
| ) | |
| if (i < 0) or (i > len(words)): | |
| raise ValueError(f"Cannot assign word at index {i}") | |
| words[i] = new_word | |
| return self.generate_new_attacked_text(words) | |
| def replace_word_at_index(self, index, new_word): | |
| """This code returns a new AttackedText object where the word at | |
| ``index`` is replaced with a new word.""" | |
| if not isinstance(new_word, str): | |
| raise TypeError( | |
| f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}" | |
| ) | |
| return self.replace_words_at_indices([index], [new_word]) | |
| def delete_word_at_index(self, index): | |
| """This code returns a new AttackedText object where the word at | |
| ``index`` is removed.""" | |
| return self.replace_word_at_index(index, "") | |
| def insert_text_after_word_index(self, index, text): | |
| """Inserts a string before word at index ``index`` and attempts to add | |
| appropriate spacing.""" | |
| if not isinstance(text, str): | |
| raise TypeError(f"text must be an str, got type {type(text)}") | |
| word_at_index = self.words[index] | |
| new_text = " ".join((word_at_index, text)) | |
| return self.replace_word_at_index(index, new_text) | |
| def insert_text_before_word_index(self, index, text): | |
| """Inserts a string before word at index ``index`` and attempts to add | |
| appropriate spacing.""" | |
| if not isinstance(text, str): | |
| raise TypeError(f"text must be an str, got type {type(text)}") | |
| word_at_index = self.words[index] | |
| # TODO if ``word_at_index`` is at the beginning of a sentence, we should | |
| # optionally capitalize ``text``. | |
| new_text = " ".join((text, word_at_index)) | |
| return self.replace_word_at_index(index, new_text) | |
| def get_deletion_indices(self): | |
| return self.attack_attrs["original_index_map"][ | |
| self.attack_attrs["original_index_map"] == -1 | |
| ] | |
    def generate_new_attacked_text(self, new_words):
        """Returns a new AttackedText object and replaces old list of words
        with a new list of words, but preserves the punctuation and spacing of
        the original message.

        ``self.words`` is a list of the words in the current text with
        punctuation removed. However, each "word" in ``new_words`` could
        be an empty string, representing a word deletion, or a string
        with multiple space-separated words, representation an insertion
        of one or more words.

        Args:
            new_words (list[str]): one entry per word of the current text.

        Returns:
            AttackedText: the perturbed text with updated ``attack_attrs``.
        """
        perturbed_text = ""
        # Join multi-column inputs with SPLIT_TOKEN so that column boundaries
        # survive the word-by-word rebuild below.
        original_text = AttackedText.SPLIT_TOKEN.join(self._text_input.values())
        new_attack_attrs = dict()
        if "label_names" in self.attack_attrs:
            new_attack_attrs["label_names"] = self.attack_attrs["label_names"]
        new_attack_attrs["newly_modified_indices"] = set()
        # Point to previously monitored text.
        new_attack_attrs["previous_attacked_text"] = self
        # Use `new_attack_attrs` to track indices with respect to the original
        # text.
        new_attack_attrs["modified_indices"] = self.attack_attrs[
            "modified_indices"
        ].copy()
        new_attack_attrs["original_index_map"] = self.attack_attrs[
            "original_index_map"
        ].copy()
        new_i = 0  # Index of the next word in the *perturbed* text.
        # Create the new attacked text by swapping out words from the original
        # text with a sequence of 0+ words in the new text.
        for i, (input_word, adv_word_seq) in enumerate(zip(self.words, new_words)):
            word_start = original_text.index(input_word)
            word_end = word_start + len(input_word)
            # Copy the spacing/punctuation before this word verbatim, then
            # consume through the end of the word.
            perturbed_text += original_text[:word_start]
            original_text = original_text[word_end:]
            adv_words = words_from_text(adv_word_seq)
            adv_num_words = len(adv_words)
            num_words_diff = adv_num_words - len(words_from_text(input_word))
            # Track indices on insertions and deletions.
            if num_words_diff != 0:
                # Re-calculated modified indices. If words are inserted or deleted,
                # they could change.
                shifted_modified_indices = set()
                for modified_idx in new_attack_attrs["modified_indices"]:
                    if modified_idx < i:
                        # Before the edit point: index unchanged.
                        shifted_modified_indices.add(modified_idx)
                    elif modified_idx > i:
                        # After the edit point: shift by the word-count delta.
                        shifted_modified_indices.add(modified_idx + num_words_diff)
                    else:
                        pass
                new_attack_attrs["modified_indices"] = shifted_modified_indices
                # Track insertions and deletions wrt original text.
                # original_modification_idx = i
                new_idx_map = new_attack_attrs["original_index_map"].copy()
                if num_words_diff == -1:
                    # Word deletion
                    new_idx_map[new_idx_map == i] = -1
                new_idx_map[new_idx_map > i] += num_words_diff
                if num_words_diff > 0 and input_word != adv_words[0]:
                    # If insertion happens before the `input_word`
                    new_idx_map[new_idx_map == i] += num_words_diff
                new_attack_attrs["original_index_map"] = new_idx_map
            # Move pointer and save indices of new modified words.
            for j in range(i, i + adv_num_words):
                if input_word != adv_word_seq:
                    new_attack_attrs["modified_indices"].add(new_i)
                    new_attack_attrs["newly_modified_indices"].add(new_i)
                new_i += 1
            # Check spaces for deleted text.
            if adv_num_words == 0 and len(original_text):
                # Remove extra space (or else there would be two spaces for each
                # deleted word).
                # @TODO What to do with punctuation in this case? This behavior is undefined.
                if i == 0:
                    # If the first word was deleted, take a subsequent space.
                    if original_text[0] == " ":
                        original_text = original_text[1:]
                else:
                    # If a word other than the first was deleted, take a preceding space.
                    if perturbed_text[-1] == " ":
                        perturbed_text = perturbed_text[:-1]
            # Add substitute word(s) to new sentence.
            perturbed_text += adv_word_seq
        perturbed_text += original_text  # Add all of the ending punctuation.
        # Add pointer to self so chain of replacements can be reconstructed.
        new_attack_attrs["prev_attacked_text"] = self
        # Reform perturbed_text into an OrderedDict.
        perturbed_input_texts = perturbed_text.split(AttackedText.SPLIT_TOKEN)
        perturbed_input = OrderedDict(
            zip(self._text_input.keys(), perturbed_input_texts)
        )
        return AttackedText(perturbed_input, attack_attrs=new_attack_attrs)
| def words_diff_ratio(self, x): | |
| """Get the ratio of words difference between current text and `x`. | |
| Note that current text and `x` must have same number of words. | |
| """ | |
| assert self.num_words == x.num_words | |
| return float(np.sum(self.words != x.words)) / self.num_words | |
    def align_with_model_tokens(self, model_wrapper):
        """Align AttackedText's `words` with target model's tokenization scheme
        (e.g. word, character, subword). Specifically, we map each word to list
        of indices of tokens that compose the word (e.g. embedding --> ["em",
        "##bed", "##ding"])

        Args:
            model_wrapper (textattack.models.wrappers.ModelWrapper): ModelWrapper of the target model

        Returns:
            word2token_mapping (dict[int, list[int]]): Dictionary that maps i-th word to list of indices.
        """
        # NOTE(review): assumes ``strip_prefix=True`` removes subword markers
        # so tokens match by plain prefix comparison — confirm against the
        # wrapper's ``tokenize`` contract.
        tokens = model_wrapper.tokenize([self.tokenizer_input], strip_prefix=True)[0]
        word2token_mapping = {}
        j = 0  # Cursor into ``tokens``; words consume tokens left to right.
        last_matched = 0  # Last token index that matched; used to rewind.
        for i, word in enumerate(self.words):
            matched_tokens = []
            # Greedily peel matching tokens off the front of ``word`` until it
            # is fully consumed or we run out of tokens (case-insensitive).
            while j < len(tokens) and len(word) > 0:
                token = tokens[j].lower()
                idx = word.lower().find(token)
                if idx == 0:
                    word = word[idx + len(token) :]
                    matched_tokens.append(j)
                    last_matched = j
                j += 1
            if not matched_tokens:
                # No tokens matched this word; record None and rewind the
                # cursor so the next word retries from the last match.
                word2token_mapping[i] = None
                j = last_matched
            else:
                word2token_mapping[i] = matched_tokens
        return word2token_mapping
| def tokenizer_input(self): | |
| """The tuple of inputs to be passed to the tokenizer.""" | |
| input_tuple = tuple(self._text_input.values()) | |
| # Prefer to return a string instead of a tuple with a single value. | |
| if len(input_tuple) == 1: | |
| return input_tuple[0] | |
| else: | |
| return input_tuple | |
| def column_labels(self): | |
| """Returns the labels for this text's columns. | |
| For single-sequence inputs, this simply returns ['text']. | |
| """ | |
| return list(self._text_input.keys()) | |
| def words_per_input(self): | |
| """Returns a list of lists of words corresponding to each input.""" | |
| if not self._words_per_input: | |
| self._words_per_input = [ | |
| words_from_text(_input) for _input in self._text_input.values() | |
| ] | |
| return self._words_per_input | |
| def words(self): | |
| if not self._words: | |
| self._words = words_from_text(self.text) | |
| return self._words | |
| def text(self): | |
| """Represents full text input. | |
| Multiply inputs are joined with a line break. | |
| """ | |
| return "\n".join(self._text_input.values()) | |
| def num_words(self): | |
| """Returns the number of words in the sequence.""" | |
| return len(self.words) | |
| def newly_swapped_words(self): | |
| return [self.words[i] for i in self.attack_attrs["newly_modified_indices"]] | |
| def printable_text(self, key_color="bold", key_color_method=None): | |
| """Represents full text input. Adds field descriptions. | |
| For example, entailment inputs look like: | |
| ``` | |
| premise: ... | |
| hypothesis: ... | |
| ``` | |
| """ | |
| # For single-sequence inputs, don't show a prefix. | |
| if len(self._text_input) == 1: | |
| return next(iter(self._text_input.values())) | |
| # For multiple-sequence inputs, show a prefix and a colon. Optionally, | |
| # color the key. | |
| else: | |
| if key_color_method: | |
| def ck(k): | |
| return textattack.shared.utils.color_text( | |
| k, key_color, key_color_method | |
| ) | |
| else: | |
| def ck(k): | |
| return k | |
| return "\n".join( | |
| f"{ck(key.capitalize())}: {value}" | |
| for key, value in self._text_input.items() | |
| ) | |
| def __repr__(self): | |
| return f'<AttackedText "{self.text}">' | |