| """ | |
| Word CNN for Classification | |
| --------------------------------------------------------------------- | |
| """ | |
| import json | |
| import os | |
| import torch | |
| from torch import nn as nn | |
| from torch.nn import functional as F | |
| import textattack | |
| from textattack.model_args import TEXTATTACK_MODELS | |
| from textattack.models.helpers import GloveEmbeddingLayer | |
| from textattack.models.helpers.utils import load_cached_state_dict | |
| from textattack.shared import utils | |

class WordCNNForClassification(nn.Module):
    """A convolutional neural network for text classification.

    We use different versions of this network to pretrain models for
    text classification.
    """

    def __init__(
        self,
        hidden_size=150,
        dropout=0.3,
        num_labels=2,
        max_seq_length=128,
        model_path=None,
        emb_layer_trainable=True,
    ):
        super().__init__()
        self._config = {
            "architectures": "WordCNNForClassification",
            "hidden_size": hidden_size,
            "dropout": dropout,
            "num_labels": num_labels,
            "max_seq_length": max_seq_length,
            "model_path": model_path,
            "emb_layer_trainable": emb_layer_trainable,
        }
        self.drop = nn.Dropout(dropout)
        self.emb_layer = GloveEmbeddingLayer(emb_layer_trainable=emb_layer_trainable)
        self.word2id = self.emb_layer.word2id
        self.encoder = CNNTextLayer(
            self.emb_layer.n_d, widths=[3, 4, 5], filters=hidden_size
        )
        d_out = 3 * hidden_size
        self.out = nn.Linear(d_out, num_labels)
        self.tokenizer = textattack.models.tokenizers.GloveTokenizer(
            word_id_map=self.word2id,
            unk_token_id=self.emb_layer.oovid,
            pad_token_id=self.emb_layer.padid,
            max_length=max_seq_length,
        )

        if model_path is not None:
            self.load_from_disk(model_path)
        self.eval()
    def load_from_disk(self, model_path):
        # TODO: Consider removing this in the future as well as loading via `model_path` in `__init__`.
        import warnings

        warnings.warn(
            "`load_from_disk` method is deprecated. Please save and load using `save_pretrained` and `from_pretrained` methods.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.load_state_dict(load_cached_state_dict(model_path))
        self.eval()
    def save_pretrained(self, output_path):
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        state_dict = {k: v.cpu() for k, v in self.state_dict().items()}
        torch.save(state_dict, os.path.join(output_path, "pytorch_model.bin"))
        with open(os.path.join(output_path, "config.json"), "w") as f:
            json.dump(self._config, f)
    @classmethod
    def from_pretrained(cls, name_or_path):
        """Load trained Word CNN model by name or from path.

        Args:
            name_or_path (:obj:`str`): Name of the model (e.g. "cnn-imdb") or model saved via :meth:`save_pretrained`.

        Returns:
            :class:`~textattack.models.helpers.WordCNNForClassification` model
        """
        if name_or_path in TEXTATTACK_MODELS:
            path = utils.download_from_s3(TEXTATTACK_MODELS[name_or_path])
        else:
            path = name_or_path

        config_path = os.path.join(path, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
        else:
            # Default config
            config = {
                "architectures": "WordCNNForClassification",
                "hidden_size": 150,
                "dropout": 0.3,
                "num_labels": 2,
                "max_seq_length": 128,
                "model_path": None,
                "emb_layer_trainable": True,
            }
        # "architectures" is not an `__init__` argument, so drop it before unpacking.
        del config["architectures"]
        model = cls(**config)
        state_dict = load_cached_state_dict(path)
        model.load_state_dict(state_dict)
        return model
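
    # Usage sketch (an illustrative addition, not part of the original module):
    # load a model either by a name listed in TEXTATTACK_MODELS (e.g. the
    # "cnn-imdb" example from the docstring above) or from a directory written
    # by `save_pretrained`; the local path below is hypothetical.
    #
    #     model = WordCNNForClassification.from_pretrained("cnn-imdb")
    #     model.save_pretrained("./cnn-imdb-copy")
    #     model = WordCNNForClassification.from_pretrained("./cnn-imdb-copy")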
    def forward(self, _input):
        # _input: (batch, seq_len) token ids -> (batch, seq_len, emb_dim) embeddings
        emb = self.emb_layer(_input)
        emb = self.drop(emb)

        # Convolution + max-pooling over time -> (batch, 3 * hidden_size)
        output = self.encoder(emb)

        output = self.drop(output)
        pred = self.out(output)  # (batch, num_labels) logits
        return pred

    def get_input_embeddings(self):
        return self.emb_layer.embedding

class CNNTextLayer(nn.Module):
    def __init__(self, n_in, widths=[3, 4, 5], filters=100):
        super().__init__()
        Ci = 1
        Co = filters
        h = n_in
        self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (w, h)) for w in widths])

    def forward(self, x):
        x = x.unsqueeze(1)  # (batch, Ci, len, d)
        x = [
            F.relu(conv(x)).squeeze(3) for conv in self.convs1
        ]  # [(batch, Co, len), ...]
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]  # [(N, Co), ...]
        x = torch.cat(x, 1)
        return x
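
# A minimal smoke-test sketch (an illustrative addition, not part of the original
# module): build an untrained model with the default config and run a forward pass
# on a dummy batch of token ids. Constructing the model still fetches the cached
# GloVe embeddings on first use; `emb_layer.embedding` is assumed to be a standard
# `nn.Embedding`, as implied by `get_input_embeddings` above.
if __name__ == "__main__":
    model = WordCNNForClassification()
    vocab_size = model.emb_layer.embedding.num_embeddings
    dummy_ids = torch.randint(0, vocab_size, (4, model._config["max_seq_length"]))
    with torch.no_grad():
        logits = model(dummy_ids)
    print(logits.shape)  # expected: torch.Size([4, 2]) with the default num_labels=2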