Spaces:
Runtime error
Runtime error
| from textattack.models.wrappers import HuggingFaceModelWrapper | |
class TADModelWrapper(HuggingFaceModelWrapper):
    """Adapt a model exposing ``infer`` to TextAttack's model-wrapper interface.

    TextAttack calls the wrapper with a batch of strings and expects one score
    vector per string back, e.g. ``[[0.218262017, 0.7817379832267761]]``.
    This wrapper runs ``model.infer`` on each string individually and collects
    the ``"probs"`` entry of every result, preserving input order.
    """

    def __init__(self, model):
        """Store ``model`` for later use by ``__call__``.

        Args:
            model: Object providing ``infer(text, print_result=..., **kwargs)``
                that returns a mapping with a ``"probs"`` key.
        """
        # HuggingFaceModelWrapper.__init__ is not called here (presumably
        # because the parent constructor also expects a tokenizer, which this
        # model does not expose). NOTE(review): self.tokenizer is therefore
        # never set; any inherited method relying on it will fail — confirm
        # none are used by the attack recipe.
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        """Return the ``"probs"`` output of ``model.infer`` for each input text.

        Args:
            text_inputs: Iterable of input strings; each one is inferred
                separately (no batching).
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``model.infer``.

        Returns:
            A list with one ``"probs"`` value per input, in input order.
            Presumably each value is a per-class probability vector — TODO
            confirm against the model's ``infer`` implementation.
        """
        # print_result=False presumably silences per-call console output from
        # infer; attacks issue many queries, so that output would be noise.
        return [
            self.model.infer(text_input, print_result=False, **kwargs)["probs"]
            for text_input in text_inputs
        ]