Spaces:
Runtime error
Runtime error
| """ | |
| GoalFunctionResult class | |
| ==================================== | |
| """ | |
| from abc import ABC, abstractmethod | |
| import torch | |
| from textattack.shared import utils | |
class GoalFunctionResultStatus:
    """Integer codes for the status of the achievement of a goal.

    The codes are consecutive integers starting at 0:

    * ``SUCCEEDED`` (0)
    * ``SEARCHING`` (1): in process of searching for a success
    * ``MAXIMIZING`` (2)
    * ``SKIPPED`` (3)
    """

    # Unpacking ``range(4)`` assigns 0, 1, 2, 3 in the order listed.
    SUCCEEDED, SEARCHING, MAXIMIZING, SKIPPED = range(4)
class GoalFunctionResult(ABC):
    """Represents the result of a goal function evaluating an AttackedText
    object.

    Args:
        attacked_text: The sequence that was evaluated.
        raw_output: The raw output. A ``torch.Tensor`` is stored as a numpy
            array; any other value is stored unchanged.
        output: The display-friendly output.
        goal_status: The ``GoalFunctionResultStatus`` representing the status of the achievement of the goal.
        score: A score representing how close the model is to achieving its goal.
            A ``torch.Tensor`` score is stored as a plain Python number.
        num_queries: How many model queries have been used
        ground_truth_output: The ground truth output
        goal_function_result_type: A label for the kind of result, shown in
            ``repr``. Defaults to the empty string.
    """

    def __init__(
        self,
        attacked_text,
        raw_output,
        output,
        goal_status,
        score,
        num_queries,
        ground_truth_output,
        goal_function_result_type="",
    ):
        self.attacked_text = attacked_text
        self.raw_output = raw_output
        self.output = output
        self.score = score
        self.goal_status = goal_status
        self.num_queries = num_queries
        self.ground_truth_output = ground_truth_output
        self.goal_function_result_type = goal_function_result_type

        if isinstance(self.raw_output, torch.Tensor):
            # ``Tensor.numpy()`` raises for tensors that require grad or that
            # live on a non-CPU device, so detach and move to CPU first.
            self.raw_output = self.raw_output.detach().cpu().numpy()

        if isinstance(self.score, torch.Tensor):
            self.score = self.score.item()

    def __repr__(self):
        """Return a multi-line summary of the result's main fields."""
        lines = [
            utils.add_indent(
                f"(goal_function_result_type): {self.goal_function_result_type}", 2
            ),
            utils.add_indent(f"(attacked_text): {self.attacked_text.text}", 2),
            utils.add_indent(f"(ground_truth_output): {self.ground_truth_output}", 2),
            utils.add_indent(f"(model_output): {self.output}", 2),
            utils.add_indent(f"(score): {self.score}", 2),
        ]
        main_str = "GoalFunctionResult( "
        main_str += "\n " + "\n ".join(lines) + "\n"
        main_str += ")"
        return main_str

    def get_text_color_input(self):
        """A string representing the color this result's changed portion should
        be if it represents the original input.

        Raises:
            NotImplementedError: Always; this class provides no implementation.
        """
        raise NotImplementedError()

    def get_text_color_perturbed(self):
        """A string representing the color this result's changed portion should
        be if it represents the perturbed input.

        Raises:
            NotImplementedError: Always; this class provides no implementation.
        """
        raise NotImplementedError()

    def get_colored_output(self, color_method=None):
        """Returns a string representation of this result's output, colored
        according to `color_method`.

        Raises:
            NotImplementedError: Always; this class provides no implementation.
        """
        raise NotImplementedError()