import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    WhisperModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from typing import Optional, Tuple, List, Union

from configs import VLFMConfig, LossFunction, LossConfig, build_tokenizer
from projector import VLFMProjector
from constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX


class VLFMModel(transformers.LlamaPreTrainedModel):
    config_class = VLFMConfig

    def __init__(self, config, torch_dtype=torch.bfloat16):
        super().__init__(config)

        # Frozen Whisper encoder for audio, frozen causal LM for text; only the
        # projector that maps encoder states into the LM embedding space is trained.
        whisper = WhisperModel.from_pretrained(config.audio_model_id, torch_dtype=torch_dtype)
        self.encoder = whisper.encoder
        self.projector = VLFMProjector(config)
        self.language_model = AutoModelForCausalLM.from_pretrained(
            config.text_model_id, torch_dtype=torch_dtype
        )

        self._train_module(self.encoder, False)
        self._train_module(self.language_model, False)
        self._train_module(self.projector, True)

        self.encoder.to(dtype=torch_dtype)
        self.language_model.to(dtype=torch_dtype)
        self.projector.to(dtype=torch_dtype)

        # The tokenizer gains a dedicated audio placeholder token; resize the LM
        # embeddings to match and keep the (possibly new) rows in the LM dtype.
        self.tokenizer, self.audio_token_id = build_tokenizer(
            config.text_model_id, config.tokenizer_padding_side
        )
        self.tokenizer_model_max_length = self.tokenizer.model_max_length
        self._resize_token_embeddings(self.tokenizer)
        self.get_input_embeddings().to(dtype=self.language_model.dtype)
        if (
            hasattr(self.language_model, "get_output_embeddings")
            and self.language_model.get_output_embeddings() is not None
        ):
            self.language_model.get_output_embeddings().to(dtype=self.language_model.dtype)

        self.loss_config = LossConfig(LossFunction.KL_Divergence)
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, new_emb):
        return self.language_model.set_input_embeddings(new_emb)

    @property
    def embed_tokens(self):
        return self.language_model.get_input_embeddings()

    def _train_module(self, module, trainable: bool):
        for param in module.parameters():
            param.requires_grad = trainable

    def _audio_iter(self, audio_batch_size):
        """Yield (batch_index, chunk_index) pairs, flattening per-sample chunk counts."""
        audio_index = 0
        for i_b, count in enumerate(audio_batch_size.view(-1).tolist()):
            for _ in range(int(count)):
                yield i_b, audio_index
                audio_index += 1

    def _resize_token_embeddings(self, tokenizer, pad_to_multiple_of=None):
        model_embeds = self.language_model.resize_token_embeddings(len(tokenizer))
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def _encode_speech(self, audio_values):
        # The encoder is frozen, so skip its graph; the projector stays differentiable.
        with torch.no_grad():
            encoder_outputs = self.encoder(audio_values, output_hidden_states=False)
            audio_embeds = encoder_outputs.last_hidden_state
        downsampled_embeds = self.projector(audio_embeds)  # (B, T, D)
        return downsampled_embeds

    def _splice_chunks(self, text_embeds, audio_embeds, audio_token_start_idx, audio_token_len, audio_batch_size):
        """Overwrite each placeholder span in text_embeds with its projected audio chunk."""
        for i_b, i_chunk in self._audio_iter(audio_batch_size):
            start = int(audio_token_start_idx[i_chunk].item())
            span = int(audio_token_len[i_chunk].item())
            a = audio_embeds[i_chunk]
            use = min(a.size(0), span)
            text_embeds[i_b, start:start + use, :] = a[:use].to(text_embeds.dtype)
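    # Distillation objective used below: with temperature T, the student
    # (audio-conditioned) distribution at each supervised position is pulled
    # toward the teacher (text-conditioned) distribution,
    #
    #     loss_kl = KL( softmax(teacher_logits / T) || softmax(student_logits / T) )
    #
    # where positions are selected by labels != IGNORE_INDEX (student) and
    # alt_labels != IGNORE_INDEX (teacher). Both selections must produce the same
    # number of rows for the element-wise F.kl_div below to be meaningful.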
    def _compute_kl_loss(
        self,
        *,
        student_logits: torch.Tensor,
        labels: torch.Tensor,
        alt_input_ids: torch.Tensor,
        alt_attention_mask: torch.Tensor,
        alt_labels: torch.Tensor,
        past_key_values=None,
        **kwargs,
    ):
        # Teacher pass: the frozen text-only LM on the transcript inputs, run in
        # eval mode and without gradients.
        lm_was_training = self.language_model.training
        self.language_model.eval()
        with torch.no_grad():
            alt_input_embeds = self.language_model.get_input_embeddings()(alt_input_ids)
            teacher_out = self.language_model(
                inputs_embeds=alt_input_embeds,
                attention_mask=alt_attention_mask,
                use_cache=False,
                return_dict=True,
                past_key_values=past_key_values,
            )
            teacher_logits = teacher_out.logits
        if lm_was_training:
            self.language_model.train()

        T = self.loss_config.kl_temperature
        student = F.log_softmax(student_logits[labels != IGNORE_INDEX] / T, dim=-1)
        teacher = F.softmax(teacher_logits[alt_labels != IGNORE_INDEX] / T, dim=-1)
        kl = F.kl_div(student, teacher, reduction="batchmean")
        return kl

    def forward(
        self,
        input_ids,
        attention_mask,
        labels=None,
        *,
        input_features=None,
        audio_token_start_idx=None,
        audio_token_len=None,
        audio_batch_size=None,
        alt_input_ids=None,
        alt_attention_mask=None,
        alt_labels=None,
        return_dict=True,
        **kwargs,
    ):
        # Embed the prompt, then overwrite the audio placeholder spans with the
        # projected Whisper features.
        tok = self.language_model.get_input_embeddings()
        text_embeds = tok(input_ids)
        if input_features is not None and audio_token_start_idx is not None:
            audio_embeds = self._encode_speech(input_features)
            self._splice_chunks(
                text_embeds, audio_embeds, audio_token_start_idx, audio_token_len, audio_batch_size
            )

        out = self.language_model(
            inputs_embeds=text_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
            use_cache=True,
        )
        logits = out.logits
        ce_loss = out.loss

        # Blend cross-entropy and distillation: loss = alpha * CE + (1 - alpha) * KL.
        alpha = self.loss_config.ce_weight
        kl = None
        if (
            self.training
            and alt_input_ids is not None
            and alt_attention_mask is not None
            and alt_labels is not None
        ):
            kl = self._compute_kl_loss(
                student_logits=logits,
                labels=labels,
                alt_input_ids=alt_input_ids,
                alt_attention_mask=alt_attention_mask,
                alt_labels=alt_labels,
                past_key_values=None,
            )
            total_loss = alpha * ce_loss + (1 - alpha) * kl
        else:
            total_loss = ce_loss

        return {
            "loss": total_loss,
            "loss_ce": ce_loss.detach() if ce_loss is not None else None,
            "loss_kl": kl.detach() if kl is not None else None,
            "logits": logits,
        }
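    # Expected tensor layout for the audio path (inferred from the splicing code
    # above; the exact mel-bin count depends on the Whisper checkpoint):
    #   input_ids / attention_mask : [B, L] prompts containing one placeholder
    #                                span of audio tokens per audio chunk
    #   input_features             : [N_chunks, n_mels, n_frames] Whisper log-mel features
    #   audio_token_start_idx      : [N_chunks] start offset of each placeholder span
    #   audio_token_len            : [N_chunks] length of each placeholder span
    #   audio_batch_size           : [B] number of audio chunks per batch sample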
    def _prepare_inputs_embeds(
        self,
        input_ids,
        attention_mask,
        *,
        input_features=None,
        audio_token_start_idx=None,
        audio_token_len=None,
        audio_batch_size=None,
    ):
        """
        Returns:
            inputs_embeds: [B, L, D] with audio spliced in
            attention_mask: [B, L] (unchanged)
        """
        tok = self.language_model.get_input_embeddings()
        inputs_embeds = tok(input_ids)  # [B, L, D]

        if input_features is not None and audio_token_start_idx is not None:
            # Normalize shapes: treat "one audio per sample" as N_chunks == B.
            feats = input_features
            if feats.dim() == 3 and feats.size(0) == input_ids.size(0):
                audio_batch_size = torch.ones(
                    input_ids.size(0), dtype=torch.long, device=input_ids.device
                )
            assert audio_batch_size is not None, "audio_batch_size required when splicing audio."

            # Encode + project, then splice.
            audio_embeds = self._encode_speech(feats)  # [N_chunks, T_audio, D]
            self._splice_chunks(
                text_embeds=inputs_embeds,
                audio_embeds=audio_embeds,
                audio_token_start_idx=audio_token_start_idx,
                audio_token_len=audio_token_len,
                audio_batch_size=audio_batch_size,
            )

        return inputs_embeds, attention_mask

    @torch.no_grad()
    def generate(
        self,
        input_ids,        # [B, L]
        attention_mask,   # [B, L]
        *,
        input_features,
        audio_token_start_idx=None,
        audio_token_len=None,
        audio_batch_size=None,
        **gen_kwargs,
    ):
        """Build spliced embeddings and call the base LM's generate."""
        self.eval()
        inputs_embeds, attn_mask = self._prepare_inputs_embeds(
            input_ids=input_ids,
            attention_mask=attention_mask,
            input_features=input_features,
            audio_token_start_idx=audio_token_start_idx,
            audio_token_len=audio_token_len,
            audio_batch_size=audio_batch_size,
        )
        return self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attn_mask,
            **gen_kwargs,
        )


AutoConfig.register("babs-vlfm", VLFMConfig)
AutoModel.register(VLFMConfig, VLFMModel)
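

# Minimal end-to-end generation sketch (not part of the model). The checkpoint
# names, the VLFMConfig constructor arguments, and the number of placeholder
# audio tokens are assumptions for illustration; the real data pipeline is
# expected to build prompts whose placeholder spans match the projector's
# output length for each clip.
if __name__ == "__main__":
    import numpy as np
    from transformers import WhisperFeatureExtractor

    config = VLFMConfig(
        audio_model_id="openai/whisper-small",      # assumed checkpoint
        text_model_id="meta-llama/Llama-3.2-1B",    # assumed checkpoint
        tokenizer_padding_side="left",
    )
    model = VLFMModel(config).eval()

    # 1 s of silence -> Whisper log-mel features, shape [1, n_mels, n_frames].
    extractor = WhisperFeatureExtractor.from_pretrained(config.audio_model_id)
    feats = extractor(
        np.zeros(16000, dtype=np.float32), sampling_rate=16000, return_tensors="pt"
    ).input_features.to(model.encoder.dtype)

    # Prompt = text prefix followed by a span of audio placeholder tokens that
    # _splice_chunks will overwrite with projected encoder states.
    num_audio_tokens = 60  # hypothetical; should match the projector output length
    prefix = model.tokenizer("Transcribe the audio: ", return_tensors="pt").input_ids
    placeholder = torch.full((1, num_audio_tokens), model.audio_token_id, dtype=torch.long)
    input_ids = torch.cat([prefix, placeholder], dim=1)
    attention_mask = torch.ones_like(input_ids)

    out = model.generate(
        input_ids,
        attention_mask,
        input_features=feats,
        audio_token_start_idx=torch.tensor([prefix.size(1)]),
        audio_token_len=torch.tensor([num_audio_tokens]),
        audio_batch_size=torch.tensor([1]),
        max_new_tokens=32,
    )
    print(model.tokenizer.batch_decode(out, skip_special_tokens=True))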