# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""Bare wrapper of HF PyTorch T5 and Perceiver with the following modifications:
- PerceiverTF encoder
- ResConv pre-encoder
- Projection layers for dynamic dimension matching
- Sinusoidal absolute positional embeddings
- Positional embeddings from the Perceiver implementation
- Task conditioning on encoder and decoder by input tokens
"""
import copy
import warnings
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint

from transformers.utils import logging
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers.modeling_utils import PreTrainedModel
from transformers.models.t5.modeling_t5 import (T5LayerNorm, T5Block, PARALLELIZE_DOCSTRING, DEPARALLELIZE_DOCSTRING,
                                                T5_START_DOCSTRING, T5_INPUTS_DOCSTRING, _CONFIG_FOR_DOC,
                                                __HEAD_MASK_WARNING_MSG)
from transformers.modeling_outputs import (Seq2SeqLMOutput, BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions)
from transformers import T5Config  # , T5PreTrainedModel
from model.ops import FixedSinusoidalPositionalEmbedding

# additional imports
from model.t5mod import T5Stack
from transformers.models.t5.modeling_t5 import (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5DenseActDense,
                                                T5DenseGatedActDense, T5Attention, load_tf_weights_in_t5,
                                                is_torch_fx_proxy)
from transformers.utils import (DUMMY_INPUTS, DUMMY_MASK)

logger = logging.get_logger(__name__)

class T5PerceiverPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    config_class = None
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["T5Block"]
    _keep_in_fp32_modules = ["wo"]
    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs
    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, T5DenseActDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model)**-0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff)**-0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5DenseGatedActDense):
            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model)**-0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model)**-0.5))
            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                module.wi_1.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff)**-0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5Attention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim)**-0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim)**-0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model)**-0.5))
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (T5Attention, T5Stack)):
            module.gradient_checkpointing = value
    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert decoder_start_token_id is not None, (
            "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
            " See T5 docs for more information")

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
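    # Illustrative note (assumption, not from the original source): with pad_token_id=0 and
    # decoder_start_token_id=0, labels [[5, 8, -100]] are shifted to [[0, 5, 8]]; any -100 that
    # remains after the shift is replaced with pad_token_id so it can be embedded safely.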

class T5PerceiverForConditionalGeneration(T5PerceiverPreTrainedModel):
    config_class = None
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["T5Block"]
    _keep_in_fp32_modules = ["wo"]
    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs
    def __init__(
        self,
        config: T5Config,
        use_fixed_absolute_pe: bool = True,
        num_max_positions: int = 1025,
    ):
        super().__init__(config)
        self.model_dim = config.d_model

        # mod: absolute position embedding
        self.use_fixed_absolute_pe = use_fixed_absolute_pe
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config,
                               self.shared,
                               use_fixed_absolute_pe=use_fixed_absolute_pe,
                               num_max_positions=num_max_positions)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config,
                               self.shared,
                               use_fixed_absolute_pe=use_fixed_absolute_pe,
                               num_max_positions=num_max_positions)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None
    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
| r""" | |
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | |
| Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., | |
| config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for | |
| labels in `[0, ..., config.vocab_size]` | |
| Returns: | |
| Examples: | |
| ```python | |
| >>> from transformers import AutoTokenizer, T5ForConditionalGeneration | |
| >>> tokenizer = AutoTokenizer.from_pretrained("t5-small") | |
| >>> model = T5ForConditionalGeneration.from_pretrained("t5-small") | |
| >>> # training | |
| >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids | |
| >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids | |
| >>> outputs = model(input_ids=input_ids, labels=labels) | |
| >>> loss = outputs.loss | |
| >>> logits = outputs.logits | |
| >>> # inference | |
| >>> input_ids = tokenizer( | |
| ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" | |
| ... ).input_ids # Batch size 1 | |
| >>> outputs = model.generate(input_ids) | |
| >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) | |
| >>> # studies have shown that owning a dog is good for you. | |
| ```""" | |
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past_key_values, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past_key_values is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past_key_values

        reordered_decoder_past = ()
        for layer_past_states in past_key_values:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (layer_past_state.index_select(
                    0, beam_idx.to(layer_past_state.device)),)

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past

from transformers import PreTrainedModel, PretrainedConfig
from transformers import AutoModel, AutoConfig
from transformers import PerceiverConfig


class MyConfig(T5Config, PerceiverConfig):
    model_type = 'mymodel'

    def __init__(self, important_param=42, **kwargs):
        super().__init__(**kwargs)
        self.important_param = important_param
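

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal smoke test written under the assumption that `model.t5mod.T5Stack` mirrors the stock
# T5 stack interface and accepts the `use_fixed_absolute_pe` / `num_max_positions` keyword
# arguments used above. The registered model type "mymodel", the toy hyperparameters, and the
# random inputs below are assumptions made for demonstration only.
if __name__ == "__main__":
    AutoConfig.register("mymodel", MyConfig)  # make MyConfig discoverable via AutoConfig

    cfg = T5Config(vocab_size=128,
                   d_model=64,
                   d_kv=8,
                   d_ff=128,
                   num_layers=2,
                   num_decoder_layers=2,
                   num_heads=4,
                   decoder_start_token_id=0)
    model = T5PerceiverForConditionalGeneration(cfg, use_fixed_absolute_pe=True, num_max_positions=1025)

    input_ids = torch.randint(0, cfg.vocab_size, (2, 16))  # (batch, source_len)
    labels = torch.randint(0, cfg.vocab_size, (2, 8))  # (batch, target_len)
    outputs = model(input_ids=input_ids, labels=labels)
    # If the custom stack behaves like stock T5, logits have shape (batch, target_len, vocab_size).
    print(outputs.loss, outputs.logits.shape)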