Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Base classes for various fairseq models. | |
| """ | |
| import logging | |
| from argparse import Namespace | |
| from typing import Dict, List, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from fairseq import utils | |
| from fairseq.data import Dictionary | |
| from fairseq.dataclass.utils import ( | |
| convert_namespace_to_omegaconf, | |
| gen_parser_from_dataclass, | |
| ) | |
| from fairseq.models import FairseqDecoder, FairseqEncoder | |
| from omegaconf import DictConfig | |
| from torch import Tensor | |
| logger = logging.getLogger(__name__) | |
def check_type(module, expected_type):
    """Assert that *module* is an instance of *expected_type*.

    When *module* has an ``unwrapped_module`` attribute, that attribute is
    checked instead of *module* itself.

    Raises:
        AssertionError: if the checked object is not an *expected_type*;
            the message shows the actual and the expected type.
    """
    # Look through a wrapper (anything exposing ``unwrapped_module``) so the
    # check applies to the module that is actually wrapped.
    target = module.unwrapped_module if hasattr(module, "unwrapped_module") else module
    assert isinstance(target, expected_type), f"{type(target)} != {expected_type}"
class BaseFairseqModel(nn.Module):
    """Base class for fairseq models."""

    def __init__(self):
        super().__init__()
        # Flipped by make_generation_fast_ so that it is only applied once.
        self._is_generation_fast = False

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser.

        Arguments are generated from the class's ``__dataclass`` attribute
        when one is set; otherwise the parser is left untouched.
        """
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            # do not set defaults so that settings defaults from various
            # architectures still works
            gen_parser_from_dataclass(parser, dc(), delete_default=True)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("Model must implement the build_model method")

    def get_targets(self, sample, net_output):
        """Get targets from either the sample or the net's output."""
        return sample["target"]

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)

    # TorchScript doesn't support super() method so that the scriptable Subclass
    # can't access the base class model in Torchscript.
    # Current workaround is to add a helper function with different name and
    # call the helper function from scriptable Subclass.
    def get_normalized_probs_scriptable(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel

        Delegates to ``self.decoder`` when the model has one. Otherwise a bare
        tensor *net_output* is treated as logits and (log-)softmaxed over the
        last dimension.

        Raises:
            NotImplementedError: if there is no decoder and *net_output* is
                not a tensor.
        """
        if hasattr(self, "decoder"):
            return self.decoder.get_normalized_probs(net_output, log_probs, sample)
        elif torch.is_tensor(net_output):
            # syntactic sugar for simple models which don't have a decoder
            # (e.g., the classification tutorial)
            logits = net_output.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)
        raise NotImplementedError

    def extract_features(self, *args, **kwargs):
        """Similar to *forward* but only return features."""
        return self(*args, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return None

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg: Optional[DictConfig] = None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.

        Args:
            state_dict: state dictionary to load; upgraded in place first
            strict: forwarded to :meth:`nn.Module.load_state_dict`
            model_cfg: model config handed to ``prune_state_dict``
            args: deprecated namespace, converted to a config when
                *model_cfg* is not given
        """
        if model_cfg is None and args is not None:
            # Logger.warn is a deprecated alias (removed in Python 3.13).
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        # Imported here rather than at module level, as in the original code.
        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)

    def upgrade_state_dict(self, state_dict):
        """Upgrade old state dicts to work with newer code."""
        self.upgrade_state_dict_named(state_dict, "")

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code.

        Args:
            state_dict (dict): state dictionary to upgrade, in place
            name (str): the state dict key corresponding to the current module
        """
        assert state_dict is not None

        def do_upgrade(m, prefix):
            # Walk the module tree depth-first, building dotted key prefixes.
            if len(prefix) > 0:
                prefix += "."

            for n, c in m.named_children():
                child_name = prefix + n
                if hasattr(c, "upgrade_state_dict_named"):
                    c.upgrade_state_dict_named(state_dict, child_name)
                elif hasattr(c, "upgrade_state_dict"):
                    c.upgrade_state_dict(state_dict)
                do_upgrade(c, child_name)

        do_upgrade(self, name)

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""
        for m in self.modules():
            # self.modules() yields self as well; skip it to avoid recursion.
            if hasattr(m, "set_num_updates") and m is not self:
                m.set_num_updates(num_updates)

    def set_epoch(self, epoch):
        """Pass the current epoch along to every submodule that accepts it."""
        for m in self.modules():
            # self.modules() yields self as well; skip it to avoid recursion.
            if hasattr(m, "set_epoch") and m is not self:
                m.set_epoch(epoch)

    def prepare_for_inference_(self, cfg: DictConfig):
        """Prepare model for inference.

        Reads the generation options from ``cfg.generation`` and forwards them
        as keyword arguments to :meth:`make_generation_fast_`.
        """
        kwargs = {}
        kwargs["beamable_mm_beam_size"] = (
            None
            if getattr(cfg.generation, "no_beamable_mm", False)
            else getattr(cfg.generation, "beam", 5)
        )
        kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False)
        if getattr(cfg.generation, "retain_dropout", False):
            kwargs["retain_dropout"] = cfg.generation.retain_dropout
            kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules
        self.make_generation_fast_(**kwargs)

    def make_generation_fast_(self, **kwargs):
        """
        Legacy entry point to optimize model for faster generation.
        Prefer prepare_for_inference_.

        After this call the model is put in eval mode and ``self.train`` is
        replaced so that switching back to training raises ``RuntimeError``.
        """
        if self._is_generation_fast:
            return  # only apply once
        self._is_generation_fast = True

        # remove weight norm from all modules in the network
        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except (AttributeError, ValueError):  # this module didn't have weight norm
                return

        self.apply(apply_remove_weight_norm)

        def apply_make_generation_fast_(module, prefix):
            if len(prefix) > 0:
                prefix += "."

            base_func = BaseFairseqModel.make_generation_fast_
            for n, m in module.named_modules():
                if (
                    m is not self
                    and hasattr(m, "make_generation_fast_")
                    # don't call this implementation again, e.g., if
                    # children modules also inherit from BaseFairseqModel
                    and m.make_generation_fast_.__func__ is not base_func
                ):
                    name = prefix + n
                    m.make_generation_fast_(name=name, **kwargs)

        apply_make_generation_fast_(self, "")

        def train(mode=True):
            if mode:
                raise RuntimeError("cannot train after make_generation_fast")

        # this model should no longer be used for training
        self.eval()
        self.train = train

    def prepare_for_onnx_export_(self, **kwargs):
        """Make model exportable via ONNX trace."""
        # Guards against preparing the same module object more than once.
        seen = set()

        def apply_prepare_for_onnx_export_(module):
            if (
                module is not self
                and hasattr(module, "prepare_for_onnx_export_")
                and module not in seen
            ):
                seen.add(module)
                module.prepare_for_onnx_export_(**kwargs)

        self.apply(apply_prepare_for_onnx_export_)

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        **kwargs,
    ):
        """
        Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
        file. Downloads and caches the pre-trained model file if needed.

        The base implementation returns a
        :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to
        generate translations or sample from language models. The underlying
        :class:`~fairseq.models.FairseqModel` can be accessed via the
        *generator.models* attribute.

        Other models may override this to implement custom hub interfaces.

        Args:
            model_name_or_path (str): either the name of a pre-trained model to
                load or a path/URL to a pre-trained model state dict
            checkpoint_file (str, optional): colon-separated list of checkpoint
                files in the model archive to ensemble (default: 'model.pt')
            data_name_or_path (str, optional): point args.data to the archive
                at the given path/URL. Can start with '.' or './' to reuse the
                model archive path.
        """
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            **kwargs,
        )
        logger.info(x["args"])
        return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"])

    @classmethod
    def hub_models(cls):
        """Return the archive map passed to ``hub_utils.from_pretrained``.

        The base implementation returns an empty dict.
        """
        return {}
class FairseqEncoderDecoderModel(BaseFairseqModel):
    """Base class for encoder-decoder models.

    Args:
        encoder (FairseqEncoder): the encoder
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        # Validate both halves (encoder first) after they have been attached.
        for component, required_type in (
            (self.encoder, FairseqEncoder),
            (self.decoder, FairseqDecoder),
        ):
            check_type(component, required_type)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Run the forward pass for an encoder-decoder model.

        The source batch goes through the encoder first; the encoder output
        and the previous decoder outputs (i.e., teacher forcing) are then fed
        to the decoder to produce the next outputs::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        # The same extra keyword arguments are handed to both halves.
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Run only the decoder on *prev_output_tokens*."""
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        return self.decoder.extract_features(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model, as an (encoder, decoder) pair."""
        encoder_limit = self.encoder.max_positions()
        decoder_limit = self.decoder.max_positions()
        return (encoder_limit, decoder_limit)

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()
class FairseqModel(FairseqEncoderDecoderModel):
    """Deprecated alias of :class:`FairseqEncoderDecoderModel`.

    Behaves exactly like its parent, but emits a deprecation warning every
    time an instance is constructed.
    """

    def __init__(self, *args, **kwargs):
        # Build the underlying encoder-decoder model first, then warn.
        super().__init__(*args, **kwargs)
        message = (
            "FairseqModel is deprecated, please use FairseqEncoderDecoderModel "
            "or BaseFairseqModel instead"
        )
        utils.deprecation_warning(message, stacklevel=4)
class FairseqMultiModel(BaseFairseqModel):
    """Base class for combining multiple encoder-decoder models."""

    def __init__(self, encoders, decoders):
        """Pair up *encoders* and *decoders*, which must share the same keys.

        Args:
            encoders: mapping from key to a FairseqEncoder
            decoders: mapping from key to a FairseqDecoder
        """
        super().__init__()
        assert encoders.keys() == decoders.keys()
        self.keys = list(encoders.keys())
        for key in self.keys:
            check_type(encoders[key], FairseqEncoder)
            check_type(decoders[key], FairseqDecoder)

        # One encoder-decoder model per key, registered as submodules.
        self.models = nn.ModuleDict(
            {
                key: FairseqEncoderDecoderModel(encoders[key], decoders[key])
                for key in self.keys
            }
        )

    @staticmethod
    def build_shared_embeddings(
        dicts: Dict[str, Dictionary],
        langs: List[str],
        embed_dim: int,
        build_embedding: callable,
        pretrained_embed_path: Optional[str] = None,
    ):
        """
        Helper function to build shared embeddings for a set of languages after
        checking that all dicts corresponding to those languages are equivalent.

        Args:
            dicts: Dict of lang_id to its corresponding Dictionary
            langs: languages that we want to share embeddings for
            embed_dim: embedding dimension
            build_embedding: callable function to actually build the embedding
            pretrained_embed_path: Optional path to load pretrained embeddings

        Raises:
            ValueError: if any language's dictionary differs from that of the
                first language in *langs*.
        """
        shared_dict = dicts[langs[0]]
        if any(dicts[lang] != shared_dict for lang in langs):
            raise ValueError(
                "--share-*-embeddings requires a joined dictionary: "
                "--share-encoder-embeddings requires a joined source "
                "dictionary, --share-decoder-embeddings requires a joined "
                "target dictionary, and --share-all-embeddings requires a "
                "joint source + target dictionary."
            )
        return build_embedding(shared_dict, embed_dim, pretrained_embed_path)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """Not implemented at this level; always raises NotImplementedError."""
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model.

        Returns:
            dict: for every key, an (encoder, decoder) pair of maximum lengths.
        """
        return {
            key: (
                self.models[key].encoder.max_positions(),
                self.models[key].decoder.max_positions(),
            )
            for key in self.keys
        }

    def max_decoder_positions(self):
        """Maximum length supported by the decoder (smallest across all models)."""
        return min(model.decoder.max_positions() for model in self.models.values())

    # ``encoder`` and ``decoder`` must be properties: forward_decoder below
    # calls ``self.decoder(...)`` directly, which would raise TypeError if
    # ``decoder`` were a plain method.
    @property
    def encoder(self):
        """Encoder of the model stored under the first key."""
        return self.models[self.keys[0]].encoder

    @property
    def decoder(self):
        """Decoder of the model stored under the first key."""
        return self.models[self.keys[0]].decoder

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Run the first model's decoder on *prev_output_tokens*."""
        return self.decoder(prev_output_tokens, **kwargs)

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg=None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.

        Args:
            state_dict: state dictionary to load; upgraded in place first
            strict: forwarded to the parent ``load_state_dict``
            model_cfg: model config handed to ``prune_state_dict``
            args: deprecated namespace, converted to a config when
                *model_cfg* is not given
        """
        if model_cfg is None and args is not None:
            # Logger.warn is a deprecated alias (removed in Python 3.13).
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        # Imported here rather than at module level, as in the original code.
        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)
class FairseqLanguageModel(BaseFairseqModel):
    """Base class for decoder-only models.

    Args:
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        check_type(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, **kwargs):
        """
        Run the forward pass for a decoder-only model.

        Feeds a batch of tokens through the decoder to predict the next tokens.

        Args:
            src_tokens (LongTensor): tokens on which to condition the decoder,
                of shape `(batch, tgt_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`;
                not a named parameter here, it is forwarded to the decoder
                through ``**kwargs`` when supplied

        Returns:
            tuple:
                - the decoder's output of shape `(batch, seq_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        return self.decoder(src_tokens, **kwargs)

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Run the decoder on *prev_output_tokens*."""
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, seq_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        return self.decoder.extract_features(src_tokens, **kwargs)

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.decoder.max_positions()

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()

    # NOTE(review): exposed as a read-only property (attribute-style access,
    # e.g. ``target in model.supported_targets``), matching upstream fairseq;
    # the decorator line appears to have been dropped from this copy of the
    # file. Confirm that no caller invokes it as a method.
    @property
    def supported_targets(self):
        """Set of target types this model supports: ``{"future"}``."""
        return {"future"}
class FairseqEncoderModel(BaseFairseqModel):
    """Base class for encoder-only models.

    Args:
        encoder (FairseqEncoder): the encoder
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        check_type(self.encoder, FairseqEncoder)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        Run the forward pass for a encoder-only model.

        The batch of tokens is passed through the encoder to generate features.

        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`

        Returns:
            the encoder's output, typically of shape `(batch, src_len, features)`
        """
        return self.encoder(src_tokens, src_lengths, **kwargs)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output.

        Reads ``net_output["encoder_out"]`` and normalizes it over the last
        dimension.

        Raises:
            NotImplementedError: if ``net_output["encoder_out"]`` is not a
                tensor.
        """
        encoder_out = net_output["encoder_out"]
        if not torch.is_tensor(encoder_out):
            raise NotImplementedError

        # Cast to float before normalizing, then pick the requested variant.
        logits = encoder_out.float()
        normalize = F.log_softmax if log_probs else F.softmax
        return normalize(logits, dim=-1)

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.encoder.max_positions()