| """ | |
| LSTM 4 Classification | |
| --------------------------------------------------------------------- | |
| """ | |
import json
import os
import warnings

import torch
from torch import nn

import textattack
from textattack.model_args import TEXTATTACK_MODELS
from textattack.models.helpers import GloveEmbeddingLayer
from textattack.models.helpers.utils import load_cached_state_dict
from textattack.shared import utils


class LSTMForClassification(nn.Module):
    """A long short-term memory neural network for text classification.

    We use different versions of this network to pretrain models for
    text classification.
    """

    def __init__(
        self,
        hidden_size=150,
        depth=1,
        dropout=0.3,
        num_labels=2,
        max_seq_length=128,
        model_path=None,
        emb_layer_trainable=True,
    ):
        super().__init__()
        self._config = {
            "architectures": "LSTMForClassification",
            "hidden_size": hidden_size,
            "depth": depth,
            "dropout": dropout,
            "num_labels": num_labels,
            "max_seq_length": max_seq_length,
            "model_path": model_path,
            "emb_layer_trainable": emb_layer_trainable,
        }
        if depth <= 1:
            # Avoid requesting non-zero dropout with only one layer:
            # PyTorch's recurrent modules never apply dropout after the last
            # recurrent layer, so with depth=1 the setting is a no-op and
            # nn.LSTM emits a warning. Disable it explicitly.
            dropout = 0
        self.drop = nn.Dropout(dropout)
        self.emb_layer_trainable = emb_layer_trainable
        self.emb_layer = GloveEmbeddingLayer(emb_layer_trainable=emb_layer_trainable)
        self.word2id = self.emb_layer.word2id
        self.encoder = nn.LSTM(
            input_size=self.emb_layer.n_d,
            # Each direction gets half the hidden size, so the concatenated
            # bidirectional output has exactly `hidden_size` features.
            hidden_size=hidden_size // 2,
            num_layers=depth,
            dropout=dropout,
            bidirectional=True,
        )
        d_out = hidden_size
        self.out = nn.Linear(d_out, num_labels)
        self.tokenizer = textattack.models.tokenizers.GloveTokenizer(
            word_id_map=self.word2id,
            unk_token_id=self.emb_layer.oovid,
            pad_token_id=self.emb_layer.padid,
            max_length=max_seq_length,
        )

        if model_path is not None:
            self.load_from_disk(model_path)
        self.eval()

    def load_from_disk(self, model_path):
        # TODO: Consider removing this in the future as well as loading via `model_path` in `__init__`.
        warnings.warn(
            "`load_from_disk` method is deprecated. Please save and load using `save_pretrained` and `from_pretrained` methods.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.load_state_dict(load_cached_state_dict(model_path))
        self.eval()

    def save_pretrained(self, output_path):
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        state_dict = {k: v.cpu() for k, v in self.state_dict().items()}
        torch.save(
            state_dict,
            os.path.join(output_path, "pytorch_model.bin"),
        )
        with open(os.path.join(output_path, "config.json"), "w") as f:
            json.dump(self._config, f)
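
    # `save_pretrained` writes a Hugging Face-style directory that
    # `from_pretrained` below can read back:
    #
    #   output_path/
    #     pytorch_model.bin   # CPU state dict saved with torch.save
    #     config.json         # the constructor kwargs from self._config
    #
    # A minimal round-trip sketch (the path is hypothetical):
    #
    #   model.save_pretrained("/tmp/my-lstm")
    #   reloaded = LSTMForClassification.from_pretrained("/tmp/my-lstm")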

    @classmethod
    def from_pretrained(cls, name_or_path):
        """Load trained LSTM model by name or from path.

        Args:
            name_or_path (:obj:`str`): Name of the model (e.g. "lstm-imdb") or model saved via :meth:`save_pretrained`.

        Returns:
            :class:`~textattack.models.helpers.LSTMForClassification` model
        """
        if name_or_path in TEXTATTACK_MODELS:
            # path = utils.download_if_needed(TEXTATTACK_MODELS[name_or_path])
            path = utils.download_from_s3(TEXTATTACK_MODELS[name_or_path])
        else:
            path = name_or_path

        config_path = os.path.join(path, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
        else:
            # Default config
            config = {
                "architectures": "LSTMForClassification",
                "hidden_size": 150,
                "depth": 1,
                "dropout": 0.3,
                "num_labels": 2,
                "max_seq_length": 128,
                "model_path": None,
                "emb_layer_trainable": True,
            }
        # "architectures" is metadata, not a constructor argument.
        del config["architectures"]
        model = cls(**config)
        state_dict = load_cached_state_dict(path)
        model.load_state_dict(state_dict)
        return model

    def forward(self, _input):
        # Ensure RNN module weights are part of a single contiguous chunk of memory.
        self.encoder.flatten_parameters()

        # _input: (batch, seq_len) token ids; the LSTM expects (seq_len, batch),
        # hence the transpose before the embedding lookup.
        emb = self.emb_layer(_input.t())
        emb = self.drop(emb)

        # output: (seq_len, batch, hidden_size); max-pool over the time
        # dimension to get one (batch, hidden_size) vector per example.
        output, hidden = self.encoder(emb)
        output = torch.max(output, dim=0)[0]
        output = self.drop(output)

        # Raw logits of shape (batch, num_labels); no softmax is applied.
        pred = self.out(output)
        return pred

    def get_input_embeddings(self):
        return self.emb_layer.embedding
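

# ---------------------------------------------------------------------
# Minimal usage sketch (not part of the library). It assumes the
# TextAttack "lstm-imdb" checkpoint referenced in the docstring above is
# downloadable, and it bypasses `model.tokenizer` in favor of the raw
# `word2id` map so the example depends only on attributes defined in
# this file. Treat it as illustrative, not canonical.
if __name__ == "__main__":
    model = LSTMForClassification.from_pretrained("lstm-imdb")

    # Map whitespace-split tokens to ids, falling back to the OOV id.
    words = "a surprisingly tender and funny film".split()
    ids = [model.word2id.get(w, model.emb_layer.oovid) for w in words]
    batch = torch.tensor([ids], dtype=torch.long)  # shape: (batch=1, seq_len)

    with torch.no_grad():
        logits = model(batch)
    print(logits.argmax(dim=1))  # predicted class index (label order assumed)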