Spaces:
Runtime error
Runtime error
| """ | |
| Determine if an attack has been successful in targeted Classification | |
| ----------------------------------------------------------------------- | |
| """ | |
| from .classification_goal_function import ClassificationGoalFunction | |
class TargetedClassification(ClassificationGoalFunction):
    """Targeted attack goal for classification models.

    The attack tries to raise the model's score for ``target_class``.
    The goal counts as complete once ``target_class`` is the model's
    predicted (argmax) label, or when the example's ground-truth output
    is already equal to ``target_class``.

    Args:
        *args: Passed through unchanged to ``ClassificationGoalFunction``.
        target_class: Index of the label the attack should make the model
            predict. Defaults to ``0``.
        **kwargs: Passed through unchanged to ``ClassificationGoalFunction``.
    """

    def __init__(self, *args, target_class=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.target_class = target_class

    def _is_goal_complete(self, model_output, _):
        """Return whether the targeted goal has been reached.

        Args:
            model_output: The model's per-class output; must support
                ``argmax()`` (presumably a 1-D tensor of scores — confirm).
            _: Unused second argument required by the base-class hook.

        Returns:
            The result of ``target_class == model_output.argmax()`` if that
            is truthy. Otherwise, whether ``ground_truth_output`` equals
            ``target_class``.
        """
        predicted_label = model_output.argmax()
        # Keep this comparison first: when it is truthy, its own value
        # (possibly a tensor) is what gets returned.
        hit_target = self.target_class == predicted_label
        # An example whose ground truth already is the target counts as done.
        return hit_target or self.ground_truth_output == self.target_class

    def _get_score(self, model_output, _):
        """Return the model's output entry at index ``target_class``.

        Args:
            model_output: Indexable per-class output with a length.
            _: Unused second argument required by the base-class hook.

        Raises:
            ValueError: If ``target_class`` is negative or not smaller than
                ``len(model_output)``.
        """
        class_count = len(model_output)
        out_of_range = self.target_class < 0 or self.target_class >= class_count
        if out_of_range:
            raise ValueError(
                f"target class set to {self.target_class} with {class_count} classes."
            )
        return model_output[self.target_class]

    def extra_repr_keys(self):
        """List the attribute names reported by this goal function's repr.

        ``"maximizable"`` is listed first, but only when ``self.maximizable``
        is truthy. ``"target_class"`` is always listed.
        """
        repr_keys = ["target_class"]
        if self.maximizable:
            repr_keys.insert(0, "maximizable")
        return repr_keys