# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

from fairseq import utils
from fairseq.models import (
    FairseqMultiModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    Embedding,
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
    base_architecture,
)
from fairseq.utils import safe_hasattr


@register_model("multilingual_transformer")
class MultilingualTransformerModel(FairseqMultiModel):
    """Train Transformer models for multiple language pairs simultaneously.

    Requires `--task multilingual_translation`.

    We inherit all arguments from TransformerModel and assume that all
    language pairs use a single Transformer architecture. In addition, we
    provide several options that are specific to the multilingual setting.

    Args:
        --share-encoder-embeddings: share encoder embeddings across all
            source languages
        --share-decoder-embeddings: share decoder embeddings across all
            target languages
        --share-encoders: share all encoder params (incl. embeddings) across
            all source languages
        --share-decoders: share all decoder params (incl. embeddings) across
            all target languages
    """

    def __init__(self, encoders, decoders):
        super().__init__(encoders, decoders)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--share-encoder-embeddings",
            action="store_true",
            help="share encoder embeddings across languages",
        )
        parser.add_argument(
            "--share-decoder-embeddings",
            action="store_true",
            help="share decoder embeddings across languages",
        )
        parser.add_argument(
            "--share-encoders",
            action="store_true",
            help="share encoders across languages",
        )
        parser.add_argument(
            "--share-decoders",
            action="store_true",
            help="share decoders across languages",
        )
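        # Note: --share-all-embeddings (one joint embedding table for all
        # encoders and decoders) is inherited from TransformerModel.add_args
        # and is also honored in build_model below.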

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        from fairseq.tasks.multilingual_translation import MultilingualTranslationTask

        assert isinstance(task, MultilingualTranslationTask)

        # make sure all arguments are present in older models
        base_multilingual_architecture(args)

        if not safe_hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not safe_hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs]
        tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs]
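
        # sharing all encoder (decoder) parameters implies sharing the
        # corresponding embedding tables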
        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        # build shared embeddings (if applicable)
        shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=task.langs,
                embed_dim=args.encoder_embed_dim,
                build_embedding=build_embedding,
                pretrained_embed_path=args.encoder_embed_path,
            )
            shared_decoder_embed_tokens = shared_encoder_embed_tokens
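            # a single table shared by all encoders and decoders also ties
            # each decoder's input and output projections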
            args.share_decoder_input_output_embed = True
        else:
            if args.share_encoder_embeddings:
                shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=src_langs,
                    embed_dim=args.encoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.encoder_embed_path,
                )
            if args.share_decoder_embeddings:
                shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=tgt_langs,
                    embed_dim=args.decoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.decoder_embed_path,
                )

        # encoders/decoders for each language
        lang_encoders, lang_decoders = {}, {}
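
        # modules are built lazily and cached per language, so a language
        # that appears in several pairs reuses the same encoder/decoder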
        def get_encoder(lang):
            if lang not in lang_encoders:
                if shared_encoder_embed_tokens is not None:
                    encoder_embed_tokens = shared_encoder_embed_tokens
                else:
                    encoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.encoder_embed_dim,
                        args.encoder_embed_path,
                    )
                lang_encoders[lang] = cls._get_module_class(
                    True, args, task.dicts[lang], encoder_embed_tokens, src_langs
                )
            return lang_encoders[lang]

        def get_decoder(lang):
            if lang not in lang_decoders:
                if shared_decoder_embed_tokens is not None:
                    decoder_embed_tokens = shared_decoder_embed_tokens
                else:
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.decoder_embed_dim,
                        args.decoder_embed_path,
                    )
                lang_decoders[lang] = cls._get_module_class(
                    False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs
                )
            return lang_decoders[lang]

        # shared encoders/decoders (if applicable)
        shared_encoder, shared_decoder = None, None
        if args.share_encoders:
            shared_encoder = get_encoder(src_langs[0])
        if args.share_decoders:
            shared_decoder = get_decoder(tgt_langs[0])

        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
            encoders[lang_pair] = (
                shared_encoder if shared_encoder is not None else get_encoder(src)
            )
            decoders[lang_pair] = (
                shared_decoder if shared_decoder is not None else get_decoder(tgt)
            )

        return MultilingualTransformerModel(encoders, decoders)

    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
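        # `langs` is unused in this base class; presumably accepted so that
        # subclasses can build language-aware encoders/decoders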
        module_class = TransformerEncoder if is_encoder else TransformerDecoder
        return module_class(args, lang_dict, embed_tokens)

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
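        # keep only the sub-model weights whose language pair exists in this
        # model, so checkpoints trained on a superset of pairs still load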
        state_dict_subset = state_dict.copy()
        for k, _ in state_dict.items():
            assert k.startswith("models.")
            lang_pair = k.split(".")[1]
            if lang_pair not in self.models:
                del state_dict_subset[k]
        super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg)


@register_model_architecture("multilingual_transformer", "multilingual_transformer")
def base_multilingual_architecture(args):
    base_architecture(args)
    args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", False)
    args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", False)
    args.share_encoders = getattr(args, "share_encoders", False)
    args.share_decoders = getattr(args, "share_decoders", False)


@register_model_architecture(
    "multilingual_transformer", "multilingual_transformer_iwslt_de_en"
)
def multilingual_transformer_iwslt_de_en(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    base_multilingual_architecture(args)
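

# Minimal usage sketch. The data directory and the de-en,fr-en pairs below
# are placeholders for your own preprocessed setup, not values from this file:
#
#   fairseq-train data-bin/multilingual \
#       --task multilingual_translation --lang-pairs de-en,fr-en \
#       --arch multilingual_transformer_iwslt_de_en \
#       --share-decoder-embeddings \
#       --optimizer adam --lr 0.0005 --criterion label_smoothed_cross_entropy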