Spaces:
Runtime error
Runtime error
| """ | |
| TensorFlow Model Wrapper | |
| -------------------------- | |
| """ | |
| import numpy as np | |
| from .model_wrapper import ModelWrapper | |
class TensorFlowModelWrapper(ModelWrapper):
    """Wraps a TensorFlow model so it can be called on raw text inputs.

    TensorFlow models can use many different architectures and
    tokenization strategies. This wrapper assumes the model takes an
    ``np.array`` of strings as input and returns a ``tf.Tensor`` of
    outputs, as is typical with Keras modules. No tokenizer is handled
    here; subclass this for models that have dedicated tokenizers or
    otherwise take input differently.

    Args:
        model: A callable TensorFlow/Keras model. It is invoked with a
            single NumPy array built from the input texts.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, text_input_list, **kwargs):
        """Run the wrapped model on a batch of raw text inputs.

        Args:
            text_input_list: Sequence of inputs (typically strings) to pass
                to the model; it is converted with ``np.array`` first.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            np.ndarray: The model outputs converted to a NumPy array.
        """
        text_array = np.array(text_input_list)
        preds = self.model(text_array)
        # Eager ``tf.Tensor`` outputs expose ``.numpy()``. Other callables
        # (e.g. wrappers around ``model.predict``) may already return NumPy
        # arrays or plain lists, which have no ``.numpy()`` method, so fall
        # back to ``np.asarray`` instead of raising AttributeError.
        to_numpy = getattr(preds, "numpy", None)
        if callable(to_numpy):
            return to_numpy()
        return np.asarray(preds)

    def get_grad(self, text_input):
        """Gradient computation is not supported by this wrapper.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "get_grad is not implemented for TensorFlowModelWrapper"
        )