Spaces:
Runtime error
Runtime error
| """ | |
| Language Models Constraint | |
| --------------------------- | |
| """ | |
| from abc import ABC, abstractmethod | |
| from textattack.constraints import Constraint | |
| class LanguageModelConstraint(Constraint, ABC): | |
| """Determines if two sentences have a swapped word that has a similar | |
| probability according to a language model. | |
| Args: | |
| max_log_prob_diff (float): the maximum decrease in log-probability | |
| in swapped words from `x` to `x_adv` | |
| compare_against_original (bool): If `True`, compare new `x_adv` against the original `x`. | |
| Otherwise, compare it against the previous `x_adv`. | |
| """ | |
| def __init__(self, max_log_prob_diff=None, compare_against_original=True): | |
| if max_log_prob_diff is None: | |
| raise ValueError("Must set max_log_prob_diff") | |
| self.max_log_prob_diff = max_log_prob_diff | |
| super().__init__(compare_against_original) | |
| def get_log_probs_at_index(self, text_list, word_index): | |
| """Gets the log-probability of items in `text_list` at index | |
| `word_index` according to a language model.""" | |
| raise NotImplementedError() | |
| def _check_constraint(self, transformed_text, reference_text): | |
| try: | |
| indices = transformed_text.attack_attrs["newly_modified_indices"] | |
| except KeyError: | |
| raise KeyError( | |
| "Cannot apply language model constraint without `newly_modified_indices`" | |
| ) | |
| for i in indices: | |
| probs = self.get_log_probs_at_index((reference_text, transformed_text), i) | |
| if len(probs) != 2: | |
| raise ValueError( | |
| f"Error: get_log_probs_at_index returned {len(probs)} values for 2 inputs" | |
| ) | |
| ref_prob, transformed_prob = probs | |
| if transformed_prob <= ref_prob - self.max_log_prob_diff: | |
| return False | |
| return True | |
| def extra_repr_keys(self): | |
| return ["max_log_prob_diff"] + super().extra_repr_keys() | |