| import torch | |
| import torch.nn as nn | |
| from typing import Optional, Union, Tuple | |
| from transformers.modeling_outputs import SequenceClassifierOutput | |
| from .wav2vec2_wrapper import Wav2VecWrapper | |
| from .multilevel_classifier import MultiLevelDownstreamModel | |
class CustomModelForAudioClassification(nn.Module):
    """Audio classifier: a wav2vec2-style encoder feeding a multi-level head.

    The downstream classifier consumes *all* encoder hidden states, so the
    config must request them (``output_hidden_states=True``).

    Args:
        config: model configuration shared by the encoder wrapper and the
            multi-level classifier.
    """

    def __init__(self, config):
        super().__init__()
        # Validate with an exception, not `assert`: asserts are stripped
        # under `python -O`, silently disabling the check.
        if not config.output_hidden_states:
            raise ValueError("The upstream model must return all hidden states")
        self.config = config
        self.encoder = Wav2VecWrapper(config)
        self.classifier = MultiLevelDownstreamModel(config, use_conv_output=True)

    def forward(
        self,
        input_features: Optional[torch.LongTensor],
        length: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Encode the audio (unless precomputed) and classify it.

        Args:
            input_features: raw audio features for the encoder; ignored when
                ``encoder_outputs`` is supplied.
            length: optional per-sample valid lengths (presumably for padding
                masking inside the encoder wrapper — confirm against
                ``Wav2VecWrapper``).
            encoder_outputs: optionally precomputed encoder outputs, assumed
                to be the same mapping ``self.encoder`` returns (it is
                unpacked as keyword arguments into the classifier and must
                contain ``'encoder_hidden_states'``).

        Returns:
            SequenceClassifierOutput with ``logits`` and the encoder hidden
            states. ``loss`` is always ``None``; loss computation is left to
            the caller.
        """
        # BUG FIX: the original ran the encoder only when encoder_outputs was
        # None but then unconditionally read the freshly-computed variable,
        # raising NameError whenever precomputed outputs were passed in.
        # Reuse the supplied outputs instead of discarding them.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_features,
                length=length,
            )
        logits = self.classifier(**encoder_outputs)
        return SequenceClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=encoder_outputs['encoder_hidden_states'],
        )