import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Wav2Vec2Config

from .torch_utils import get_mask_from_lengths
from .wav2vec2 import Wav2Vec2Model

class Audio2MeshModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        out_dim = config['out_dim']
        latent_dim = config['latent_dim']
        model_path = config['model_path']
        only_last_fetures = config['only_last_fetures']  # key spelling kept as in the config files
        from_pretrained = config['from_pretrained']

        self._only_last_features = only_last_fetures

        # Load the wav2vec2 audio encoder, either with pretrained weights or from the config alone.
        self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True)
        if from_pretrained:
            self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True)
        else:
            self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config)
        # Keep the convolutional feature extractor frozen during training.
        self.audio_encoder.feature_extractor._freeze_parameters()

        hidden_size = self.audio_encoder_config.hidden_size

        # Project wav2vec2 hidden states into the latent space, then to the mesh output.
        self.in_fn = nn.Linear(hidden_size, latent_dim)
        self.out_fn = nn.Linear(latent_dim, out_dim)
        # Zero-initialize the output layer so the model starts from a zero mesh offset.
        nn.init.constant_(self.out_fn.weight, 0)
        nn.init.constant_(self.out_fn.bias, 0)
    def forward(self, audio, label, audio_len=None):
        # Build an attention mask from the audio lengths; truth-testing a tensor
        # is ambiguous, so compare against None explicitly.
        attention_mask = ~get_mask_from_lengths(audio_len) if audio_len is not None else None

        seq_len = label.shape[1]

        embeddings = self.audio_encoder(audio, seq_len=seq_len, output_hidden_states=True,
                                        attention_mask=attention_mask)

        if self._only_last_features:
            hidden_states = embeddings.last_hidden_state
        else:
            # Average the hidden states across all encoder layers.
            hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states)

        layer_in = self.in_fn(hidden_states)
        out = self.out_fn(layer_in)

        return out, None
    def infer(self, input_value, seq_len):
        embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True)

        if self._only_last_features:
            hidden_states = embeddings.last_hidden_state
        else:
            hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states)

        layer_in = self.in_fn(hidden_states)
        out = self.out_fn(layer_in)

        return out
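

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): one way Audio2MeshModel might
# be instantiated and run for inference. The config keys mirror the ones read
# in __init__ above; the dimensions, local checkpoint path, and the
# Wav2Vec2FeatureExtractor preprocessing are illustrative assumptions.
# Because this module uses relative imports, run it as part of its package
# (e.g. with `python -m`) rather than as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import Wav2Vec2FeatureExtractor

    cfg = {
        'out_dim': 1404,          # assumed: e.g. 468 landmarks * 3 coordinates
        'latent_dim': 512,        # assumed latent width
        'model_path': './pretrained_model/wav2vec2-base-960h',  # assumed local checkpoint
        'only_last_fetures': True,  # key spelling follows the config read in __init__
        'from_pretrained': True,
    }
    model = Audio2MeshModel(cfg).eval()

    # Normalize a raw 16 kHz waveform into wav2vec2 input values.
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cfg['model_path'],
                                                                 local_files_only=True)
    waveform = torch.randn(16000 * 2)  # two seconds of dummy audio
    input_values = feature_extractor(waveform.numpy(), sampling_rate=16000,
                                     return_tensors='pt').input_values

    with torch.no_grad():
        # seq_len is the number of mesh frames to predict, e.g. 30 fps * 2 s.
        mesh_params = model.infer(input_values, seq_len=60)
    # Expected shape (1, 60, out_dim), assuming the custom Wav2Vec2Model
    # resamples its features to seq_len frames.
    print(mesh_params.shape)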