# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Any, Dict

from fairseq import checkpoint_utils
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
    base_architecture as transformer_base_architecture,
)


@register_model("transformer_from_pretrained_xlm")
class TransformerFromPretrainedXLMModel(TransformerModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--pretrained-xlm-checkpoint",
            type=str,
            metavar="STR",
            help="XLM model to use for initializing transformer encoder and/or decoder",
        )
        parser.add_argument(
            "--init-encoder-only",
            action="store_true",
            help="if set, don't load the XLM weights and embeddings into decoder",
        )
        parser.add_argument(
            "--init-decoder-only",
            action="store_true",
            help="if set, don't load the XLM weights and embeddings into encoder",
        )
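
    # Rough sketch of how these flags might be combined on the command line.
    # The data directory, checkpoint path, language pair, and choice of
    # --init-encoder-only are hypothetical, not taken from this file; it
    # assumes data binarized with MaskedLMDictionary and the
    # translation_from_pretrained_xlm task referenced in build_model below:
    #
    #   fairseq-train data-bin/example-en-de \
    #       --arch transformer_from_pretrained_xlm \
    #       --task translation_from_pretrained_xlm \
    #       --pretrained-xlm-checkpoint /path/to/xlm_checkpoint.pt \
    #       --init-encoder-only \
    #       --source-lang en --target-lang de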

    @classmethod
    def build_model(cls, args, task, cls_dictionary=MaskedLMDictionary):
        assert hasattr(args, "pretrained_xlm_checkpoint"), (
            "You must specify a path for --pretrained-xlm-checkpoint to use "
            "--arch transformer_from_pretrained_xlm"
        )
        assert isinstance(task.source_dictionary, cls_dictionary) and isinstance(
            task.target_dictionary, cls_dictionary
        ), (
            "You should use a MaskedLMDictionary when using --arch "
            "transformer_from_pretrained_xlm because the pretrained XLM model "
            "was trained using data binarized with MaskedLMDictionary. "
            "For translation, you may want to use --task "
            "translation_from_pretrained_xlm"
        )
        assert not (
            getattr(args, "init_encoder_only", False)
            and getattr(args, "init_decoder_only", False)
        ), "Only one of --init-encoder-only and --init-decoder-only can be set."
        return super().build_model(args, task)

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerEncoderFromPretrainedXLM(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return TransformerDecoderFromPretrainedXLM(args, tgt_dict, embed_tokens)


def upgrade_state_dict_with_xlm_weights(
    state_dict: Dict[str, Any], pretrained_xlm_checkpoint: str
) -> Dict[str, Any]:
    """
    Load XLM weights into a Transformer encoder or decoder model.

    Args:
        state_dict: state dict for either TransformerEncoder or
            TransformerDecoder
        pretrained_xlm_checkpoint: checkpoint to load XLM weights from

    Raises:
        AssertionError: If architecture (num layers, attention heads, etc.)
            does not match between the current Transformer encoder or
            decoder and the pretrained_xlm_checkpoint
    """
    if not os.path.exists(pretrained_xlm_checkpoint):
        raise IOError("Model file not found: {}".format(pretrained_xlm_checkpoint))

    state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
    xlm_state_dict = state["model"]
    for key in xlm_state_dict.keys():
        # Copy over embedding and layer weights. The XLM checkpoint keys carry
        # a module prefix, so strip everything before the shared component name
        # to line them up with the Transformer encoder/decoder state_dict keys.
        for search_key in ["embed_tokens", "embed_positions", "layers"]:
            if search_key in key:
                subkey = key[key.find(search_key) :]
                assert subkey in state_dict, (
                    "{} Transformer encoder / decoder "
                    "state_dict does not contain {}. Cannot "
                    "load {} from pretrained XLM checkpoint "
                    "{} into Transformer.".format(
                        str(state_dict.keys()), subkey, key, pretrained_xlm_checkpoint
                    )
                )
                state_dict[subkey] = xlm_state_dict[key]
    return state_dict
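
# Illustrative sketch (not from this file) of the key renaming performed above,
# assuming an XLM checkpoint whose keys carry a "sentence_encoder." prefix; the
# exact prefix and parameter names depend on how the XLM model was saved:
#
#   "sentence_encoder.embed_tokens.weight"   -> "embed_tokens.weight"
#   "sentence_encoder.layers.0.fc1.weight"   -> "layers.0.fc1.weight"
#
# Each right-hand key must already exist in the target encoder/decoder
# state_dict (same number of layers, heads, and dimensions); otherwise the
# assertion above fires.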


class TransformerEncoderFromPretrainedXLM(TransformerEncoder):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
        if getattr(args, "init_decoder_only", False):
            # Don't load XLM weights for encoder if --init-decoder-only
            return

        assert hasattr(args, "pretrained_xlm_checkpoint"), (
            "--pretrained-xlm-checkpoint must be specified to load Transformer "
            "encoder from pretrained XLM"
        )
        xlm_loaded_state_dict = upgrade_state_dict_with_xlm_weights(
            state_dict=self.state_dict(),
            pretrained_xlm_checkpoint=args.pretrained_xlm_checkpoint,
        )
        self.load_state_dict(xlm_loaded_state_dict, strict=True)


class TransformerDecoderFromPretrainedXLM(TransformerDecoder):
    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(args, dictionary, embed_tokens, no_encoder_attn)
        if getattr(args, "init_encoder_only", False):
            # Don't load XLM weights for decoder if --init-encoder-only
            return

        assert hasattr(args, "pretrained_xlm_checkpoint"), (
            "--pretrained-xlm-checkpoint must be specified to load Transformer "
            "decoder from pretrained XLM"
        )
        xlm_loaded_state_dict = upgrade_state_dict_with_xlm_weights(
            state_dict=self.state_dict(),
            pretrained_xlm_checkpoint=args.pretrained_xlm_checkpoint,
        )
        self.load_state_dict(xlm_loaded_state_dict, strict=True)


@register_model_architecture(
    "transformer_from_pretrained_xlm", "transformer_from_pretrained_xlm"
)
def base_architecture(args):
    transformer_base_architecture(args)