Spaces:
Sleeping
Sleeping
| """ | |
| ClassificationGoalFunctionResult Class | |
| ======================================== | |
| """ | |
| import torch | |
| import textattack | |
| from textattack.shared import utils | |
| from .goal_function_result import GoalFunctionResult | |
class ClassificationGoalFunctionResult(GoalFunctionResult):
    """Represents the result of a classification goal function."""

    def __init__(
        self,
        attacked_text,
        raw_output,
        output,
        goal_status,
        score,
        num_queries,
        ground_truth_output,
    ):
        # All state is stored by the base class; this subclass only tags the
        # result type as "Classification".
        super().__init__(
            attacked_text,
            raw_output,
            output,
            goal_status,
            score,
            num_queries,
            ground_truth_output,
            goal_function_result_type="Classification",
        )

    @property
    def _processed_output(self):
        """Takes a model output (like `1`) and returns the class labeled output
        (like `positive`), if possible.

        Also returns the associated color.

        This is a read-only property: every caller in this class reads it as
        an attribute and unpacks the result (``_, color =
        self._processed_output``). Without ``@property`` that unpacking would
        raise ``TypeError`` on the bound-method object.

        Returns:
            A ``(output, color)`` tuple. When the attacked text carries
            ``attack_attrs["label_names"]``, ``output`` is the processed name
            of ``label_names[self.output]``. Otherwise it is the argmax index
            of ``self.raw_output``. ``color`` comes from
            ``textattack.shared.utils``.
        """
        # Index of the highest-scoring entry of raw_output. It is used for
        # the color (and as the output when no label names are available).
        # NOTE(review): presumably raw_output is a 1-D tensor/array of
        # per-class scores; confirm against the goal function.
        output_label = self.raw_output.argmax()
        label_names = self.attacked_text.attack_attrs.get("label_names")
        if label_names is not None:
            # The displayed name is looked up by self.output, while the color
            # is keyed on the argmax index above.
            output = label_names[self.output]
            output = textattack.shared.utils.process_label_name(output)
            color = textattack.shared.utils.color_from_output(output, output_label)
            return output, color
        color = textattack.shared.utils.color_from_label(output_label)
        return output_label, color

    def get_text_color_input(self):
        """A string representing the color this result's changed portion should
        be if it represents the original input."""
        _, color = self._processed_output
        return color

    def get_text_color_perturbed(self):
        """A string representing the color this result's changed portion should
        be if it represents the perturbed input."""
        _, color = self._processed_output
        return color

    def get_colored_output(self, color_method=None):
        """Returns a string representation of this result's output, colored
        according to `color_method`.

        The string has the form ``"<label> (<confidence>%)"``. The confidence
        is the entry of ``raw_output`` at its argmax index, rendered as a
        whole-number percentage.

        Args:
            color_method: Passed through unchanged to ``utils.color_text``.

        Returns:
            The colored output string produced by ``utils.color_text``.
        """
        output_label = self.raw_output.argmax()
        confidence_score = self.raw_output[output_label]
        # Unwrap a 0-d torch tensor to a Python number before formatting.
        if isinstance(confidence_score, torch.Tensor):
            confidence_score = confidence_score.item()
        output, color = self._processed_output
        # concatenate with label and convert confidence score to percent, like '33%'
        output_str = f"{output} ({confidence_score:.0%})"
        return utils.color_text(output_str, color=color, method=color_method)