Spaces:
Runtime error
Runtime error
| """ | |
| Misc Validators | |
| ================= | |
| Validators ensure compatibility between search methods, transformations, constraints, and goal functions. | |
| """ | |
| import re | |
| import textattack | |
| from textattack.goal_functions import ( | |
| InputReduction, | |
| MinimizeBleu, | |
| NonOverlappingOutput, | |
| TargetedClassification, | |
| UntargetedClassification, | |
| ) | |
| from . import logger | |
# A list of goal functions and the corresponding available models.
#
# Each key is a tuple of goal function classes; each value is a list of regular
# expressions that are matched (with ``re.match``) against a model's fully
# qualified class path, i.e. ``<module>.<ClassName>``. Literal dots in the
# package path are escaped so that ``.`` only matches a real separator rather
# than an arbitrary character.
MODELS_BY_GOAL_FUNCTIONS = {
    (TargetedClassification, UntargetedClassification, InputReduction): [
        r"^textattack\.models\.helpers\.lstm_for_classification.*",
        r"^textattack\.models\.helpers\.word_cnn_for_classification.*",
        r"^transformers\.modeling_\w*\.\w*ForSequenceClassification$",
    ],
    (
        NonOverlappingOutput,
        MinimizeBleu,
    ): [
        r"^textattack\.models\.helpers\.t5_for_text_to_text.*",
    ],
}
# Flatten `MODELS_BY_GOAL_FUNCTIONS` (plural, keyed by tuples of goal
# functions) into `MODELS_BY_GOAL_FUNCTION` (singular, keyed by one goal
# function each). Every goal function in a group maps to the very same
# pattern list object as the group it came from.
MODELS_BY_GOAL_FUNCTION = {}
for goal_functions, matching_model_globs in MODELS_BY_GOAL_FUNCTIONS.items():
    MODELS_BY_GOAL_FUNCTION.update(
        dict.fromkeys(goal_functions, matching_model_globs)
    )
def validate_model_goal_function_compatibility(goal_function_class, model_class):
    """Determines if ``model_class`` is task-compatible with
    ``goal_function_class``.

    For example, a text-generative model like one intended for
    translation or summarization would not be compatible with a goal
    function that requires probability scores, such as
    ``UntargetedClassification``.

    The check is advisory only: the outcome is reported through ``logger``
    and nothing is raised for an incompatible or unknown combination.

    Args:
        goal_function_class: The goal function class, used as the key into
            ``MODELS_BY_GOAL_FUNCTION``.
        model_class: The model class. Its ``__module__`` and ``__name__`` are
            joined with a dot and matched against the registered patterns.

    Returns:
        None in every case.
    """
    # Verify that this is a valid goal function. An unregistered goal function
    # is only warned about; the checks below still run with no patterns so the
    # model can be compared against the other goal functions' patterns.
    try:
        matching_model_globs = MODELS_BY_GOAL_FUNCTION[goal_function_class]
    except KeyError:
        matching_model_globs = []
        logger.warning(f"No entry found for goal function {goal_function_class}.")
    # Fully qualified class path, e.g. ``package.module.ClassName``.
    model_module_path = ".".join((model_class.__module__, model_class.__name__))
    # Ensure the model matches one of the patterns for this goal function.
    for glob in matching_model_globs:
        if re.match(glob, model_module_path):
            logger.info(
                f"Goal function {goal_function_class} compatible with model {model_class.__name__}."
            )
            return
    # If we got here, the model does not match the intended goal function.
    # If it matches another goal function, warn user.
    for goal_functions, globs in MODELS_BY_GOAL_FUNCTIONS.items():
        for glob in globs:
            if re.match(glob, model_module_path):
                logger.warning(
                    f"Unknown if model {model_class.__name__} compatible with provided goal function {goal_function_class}."
                    f" Found match with other goal functions: {goal_functions}."
                )
                return
    # Otherwise, this is an unknown model–perhaps user-provided, or we forgot to
    # update the corresponding dictionary. Warn user and return.
    logger.warning(
        f"Unknown if model of class {model_class} compatible with goal function {goal_function_class}."
    )
def validate_model_gradient_word_swap_compatibility(model):
    """Checks that ``model`` can be used with ``GradientBasedWordSwap``.

    The gradient can only be taken with respect to an individual word when
    the model uses a word-based tokenizer; here that is checked by requiring
    ``model`` to be an ``LSTMForClassification`` instance.

    Returns:
        ``True`` when ``model`` is an ``LSTMForClassification``.

    Raises:
        ValueError: For any other kind of model.
    """
    supported_model_class = textattack.models.helpers.LSTMForClassification
    if not isinstance(model, supported_model_class):
        raise ValueError(f"Cannot perform GradientBasedWordSwap on model {model}.")
    return True
def transformation_consists_of(transformation, transformation_classes):
    """Checks whether ``transformation`` is built solely from the given classes.

    A ``CompositeTransformation`` qualifies when every entry of its
    ``transformations`` attribute qualifies (checked recursively). Any other
    transformation qualifies when it is an instance of at least one class in
    ``transformation_classes``.

    Returns:
        ``True`` if the transformation qualifies, ``False`` otherwise.
    """
    from textattack.transformations import CompositeTransformation

    if not isinstance(transformation, CompositeTransformation):
        # Leaf case: a single allowed class is enough.
        return any(
            isinstance(transformation, allowed_class)
            for allowed_class in transformation_classes
        )

    # Composite case: every member must itself pass the check.
    return all(
        transformation_consists_of(member, transformation_classes)
        for member in transformation.transformations
    )
def transformation_consists_of_word_swaps(transformation):
    """Checks whether ``transformation`` is a word swap, or a composite made
    up exclusively of word swaps.

    Delegates to ``transformation_consists_of`` with the ``WordSwap`` and
    ``WordSwapGradientBased`` classes.
    """
    from textattack.transformations import WordSwap, WordSwapGradientBased

    word_swap_classes = [WordSwap, WordSwapGradientBased]
    return transformation_consists_of(transformation, word_swap_classes)
def transformation_consists_of_word_swaps_and_deletions(transformation):
    """Checks whether ``transformation`` is a word swap or word deletion, or a
    composite made up exclusively of word swaps and deletions.

    Delegates to ``transformation_consists_of`` with the ``WordDeletion``,
    ``WordSwap`` and ``WordSwapGradientBased`` classes.
    """
    from textattack.transformations import WordDeletion, WordSwap, WordSwapGradientBased

    swap_and_deletion_classes = [WordDeletion, WordSwap, WordSwapGradientBased]
    return transformation_consists_of(transformation, swap_and_deletion_classes)