import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import (
    AutoConfig, AutoModel,
    AutoModelForCausalLM, WhisperModel)

from configs import VLFMConfig, LossFunction, LossConfig, build_tokenizer
from projector import VLFMProjector
from constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from typing import Optional, Tuple, List, Union


class VLFMModel(transformers.LlamaPreTrainedModel):
    """Speech-language model: a frozen Whisper encoder and a frozen causal LM
    bridged by a trainable projector. Projected audio embeddings are spliced
    into the text embedding sequence at the reserved audio-token positions."""

    config_class = VLFMConfig

    def __init__(self, config, torch_dtype=torch.bfloat16):
        super().__init__(config)

        whisper = WhisperModel.from_pretrained(config.audio_model_id,
                                               torch_dtype=torch_dtype)

        self.encoder = whisper.encoder
        self.projector = VLFMProjector(config)
        self.language_model = AutoModelForCausalLM.from_pretrained(config.text_model_id,
                                                                   torch_dtype=torch_dtype)

        # Only the projector is trained; encoder and language model stay frozen.
        self._train_module(self.encoder, False)
        self._train_module(self.language_model, False)
        self._train_module(self.projector, True)

        self.encoder.to(dtype=torch_dtype)
        self.language_model.to(dtype=torch_dtype)
        self.projector.to(dtype=torch_dtype)

        self.tokenizer, self.audio_token_id = build_tokenizer(config.text_model_id, config.tokenizer_padding_side)

        self.tokenizer_model_max_length = self.tokenizer.model_max_length
        self._resize_token_embeddings(self.tokenizer)
        self.get_input_embeddings().to(dtype=self.language_model.dtype)
        if hasattr(self.language_model, "get_output_embeddings") and self.language_model.get_output_embeddings() is not None:
            self.language_model.get_output_embeddings().to(dtype=self.language_model.dtype)

        self.loss_config = LossConfig(LossFunction.KL_Divergence)

        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, new_emb):
        return self.language_model.set_input_embeddings(new_emb)

    @property
    def embed_tokens(self):
        return self.language_model.get_input_embeddings()

    def _train_module(self, module, trainable: bool):
        for param in module.parameters():
            param.requires_grad = trainable

    def _audio_iter(self, audio_batch_size):
        """Yield (batch_index, flat_chunk_index) pairs, where audio_batch_size[i]
        is the number of audio chunks belonging to batch element i."""
        audio_index = 0
        for i_b, count in enumerate(audio_batch_size.view(-1).tolist()):
            for _ in range(int(count)):
                yield i_b, audio_index
                audio_index += 1

    def _resize_token_embeddings(self, tokenizer, pad_to_multiple_of=None):
        model_embeds = self.language_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=pad_to_multiple_of)
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def _encode_speech(self, audio_values):
        # The encoder is frozen, so skip building its graph; the projector is
        # trainable and therefore runs outside the no_grad block.
        with torch.no_grad():
            encoder_outputs = self.encoder(audio_values, output_hidden_states=False)
            audio_embeds = encoder_outputs.last_hidden_state
        downsampled_embeds = self.projector(audio_embeds)

        return downsampled_embeds

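    # Shape sketch (an assumption based on the stock Whisper feature extractor
    # upstream): `input_features` arrives as [num_audio_chunks, n_mels, n_frames],
    # the encoder returns hidden states of shape [num_audio_chunks, T_enc, D_enc],
    # and the projector maps them to [num_audio_chunks, T_audio, D_text] so each
    # chunk can be written over its reserved span of text embeddings below.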
    def _splice_chunks(self, text_embeds, audio_embeds, audio_token_start_idx, audio_token_len, audio_batch_size):
        """Write each audio chunk's embeddings over its reserved placeholder span
        in `text_embeds`, in place, truncating to the shorter of the two lengths."""
        for i_b, i_chunk in self._audio_iter(audio_batch_size):
            start = int(audio_token_start_idx[i_chunk].item())
            span = int(audio_token_len[i_chunk].item())
            a = audio_embeds[i_chunk]
            Ta = a.size(0)
            use = min(Ta, span)
            text_embeds[i_b, start:start + use, :] = a[:use].to(text_embeds.dtype)

    def _compute_kl_loss(
        self,
        *,
        student_logits: torch.Tensor,
        labels: torch.Tensor,
        alt_input_ids: torch.Tensor,
        alt_attention_mask: torch.Tensor,
        alt_labels: torch.Tensor,
        past_key_values=None,
        **kwargs,
    ):
        lm_was_training = self.language_model.training
        self.language_model.eval()
        with torch.no_grad():
            alt_input_embeds = self.language_model.get_input_embeddings()(alt_input_ids)
            teacher_out = self.language_model(
                inputs_embeds=alt_input_embeds,
                attention_mask=alt_attention_mask,
                use_cache=False,
                return_dict=True,
                past_key_values=past_key_values,
            )
            teacher_logits = teacher_out.logits
        if lm_was_training:
            self.language_model.train()

        T = self.loss_config.kl_temperature
        student = F.log_softmax(student_logits[labels != IGNORE_INDEX] / T, dim=-1)
        teacher = F.softmax(teacher_logits[alt_labels != IGNORE_INDEX] / T, dim=-1)
        kl = F.kl_div(student, teacher, reduction="batchmean")
        return kl

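    # Training objective used by forward(): with alpha = loss_config.ce_weight and
    # temperature T = loss_config.kl_temperature,
    #
    #     loss = alpha * CE(student, labels) + (1 - alpha) * KL(teacher_T || student_T)
    #
    # where the student runs on the audio-spliced sequence, the teacher is the frozen
    # language model run on the text-only ("alt") inputs, and both distributions are
    # restricted to supervised positions (labels != IGNORE_INDEX).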
    def forward(
        self,
        input_ids,
        attention_mask,
        labels=None,
        *,
        input_features=None,
        audio_token_start_idx=None,
        audio_token_len=None,
        audio_batch_size=None,
        alt_input_ids=None,
        alt_attention_mask=None,
        alt_labels=None,
        return_dict=True,
        **kwargs,
    ):
        embed_layer = self.language_model.get_input_embeddings()
        text_embeds = embed_layer(input_ids)

        if input_features is not None and audio_token_start_idx is not None:
            audio_embeds = self._encode_speech(input_features)
            self._splice_chunks(
                text_embeds,
                audio_embeds,
                audio_token_start_idx,
                audio_token_len,
                audio_batch_size,
            )

        out = self.language_model(
            inputs_embeds=text_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
            use_cache=True,
        )

        logits = out.logits
        ce_loss = out.loss

        alpha = self.loss_config.ce_weight

        kl = None
        if (
            self.training
            and alt_input_ids is not None
            and alt_attention_mask is not None
            and alt_labels is not None
        ):
            kl = self._compute_kl_loss(
                student_logits=logits,
                labels=labels,
                alt_input_ids=alt_input_ids,
                alt_attention_mask=alt_attention_mask,
                alt_labels=alt_labels,
                past_key_values=None,
            )

            total_loss = alpha * ce_loss + (1 - alpha) * kl
        else:
            total_loss = ce_loss

        return {
            "loss": total_loss,
            "loss_ce": ce_loss.detach() if ce_loss is not None else None,
            "loss_kl": kl.detach() if kl is not None else None,
            "logits": logits,
        }

    def _prepare_inputs_embeds(
        self,
        input_ids,
        attention_mask,
        *,
        input_features=None,
        audio_token_start_idx=None,
        audio_token_len=None,
        audio_batch_size=None,
    ):
        """
        Returns:
            inputs_embeds: [B, L, D] with audio spliced in
            attention_mask: [B, L] (unchanged)
        """
        embed_layer = self.language_model.get_input_embeddings()
        inputs_embeds = embed_layer(input_ids)

        if input_features is not None and audio_token_start_idx is not None:
            feats = input_features
            # Default to one audio chunk per batch element when the caller passes
            # a [B, n_mels, n_frames] feature tensor without an explicit count.
            if audio_batch_size is None and feats.dim() == 3 and feats.size(0) == input_ids.size(0):
                audio_batch_size = torch.ones(input_ids.size(0), dtype=torch.long, device=input_ids.device)
            assert audio_batch_size is not None, "audio_batch_size required when splicing audio."

            audio_embeds = self._encode_speech(feats)
            self._splice_chunks(
                text_embeds=inputs_embeds,
                audio_embeds=audio_embeds,
                audio_token_start_idx=audio_token_start_idx,
                audio_token_len=audio_token_len,
                audio_batch_size=audio_batch_size,
            )

        return inputs_embeds, attention_mask

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask,
        *,
        input_features,
        audio_token_start_idx=None,
        audio_token_len=None,
        audio_batch_size=None,
        **gen_kwargs,
    ):
        """Build spliced embeddings and call the base LM's generate."""
        self.eval()
        inputs_embeds, attn_mask = self._prepare_inputs_embeds(
            input_ids=input_ids,
            attention_mask=attention_mask,
            input_features=input_features,
            audio_token_start_idx=audio_token_start_idx,
            audio_token_len=audio_token_len,
            audio_batch_size=audio_batch_size,
        )
        return self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attn_mask,
            **gen_kwargs,
        )


AutoConfig.register("babs-vlfm", VLFMConfig)
AutoModel.register(VLFMConfig, VLFMModel)
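

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library proper.
    # Assumptions: VLFMConfig accepts these keyword arguments and the example
    # checkpoint ids below are reachable; swap in the ids you actually train with.
    cfg = VLFMConfig(
        audio_model_id="openai/whisper-small",      # assumed example audio encoder
        text_model_id="meta-llama/Llama-3.2-1B",    # assumed example language model
        tokenizer_padding_side="left",              # assumed padding side
    )
    model = VLFMModel(cfg)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    # Only the projector should report trainable parameters.
    print(f"trainable params: {trainable:,} / {total:,}")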