Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import contextlib | |
| import logging | |
| from argparse import Namespace | |
| from dataclasses import dataclass, field | |
| from typing import Any, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from omegaconf import II, MISSING, open_dict | |
| from fairseq import checkpoint_utils, tasks, utils | |
| from fairseq.dataclass import ChoiceEnum, FairseqDataclass | |
| from fairseq.dataclass.utils import convert_namespace_to_omegaconf | |
| from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model | |
| from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES, Wav2Vec2Config | |
| from fairseq.models.wav2vec.wav2vec2_asr import Embedding, Linear, Wav2VecEncoder, Wav2Vec2AsrConfig | |
| from fairseq.tasks import FairseqTask | |
| logging.basicConfig(level=logging.DEBUG) | |
| class Wav2Vec2ClassificationConfig(Wav2Vec2AsrConfig): | |
| latent_embed_dim: Optional[int] = field( | |
| default=None, metadata={"help": "latent dim (encoder w2v -> latent -> class"} | |
| ) | |
| pooling: str = field( | |
| default="first_token", | |
| metadata={"help": "pooling layer choices"}, | |
| ) | |
| activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( | |
| default="gelu", metadata={"help": "activation function to use"} | |
| ) | |
| class Wav2VecClassification(BaseFairseqModel): | |
| # TODO: Can be shared/merged with ASR model class as w2v_encoder params are common. | |
| def __init__( | |
| self, | |
| cfg: Wav2Vec2ClassificationConfig, | |
| w2v_encoder: BaseFairseqModel, | |
| pooling_layer, | |
| ): | |
| super().__init__() | |
| self.cfg = cfg | |
| self.w2v_encoder = w2v_encoder | |
| self.pooling_layer = pooling_layer | |
| def upgrade_state_dict_named(self, state_dict, name): | |
| super().upgrade_state_dict_named(state_dict, name) | |
| return state_dict | |
| def build_model(cls, cfg: Wav2Vec2ClassificationConfig, task: FairseqTask): | |
| """Build a new model instance.""" | |
| w2v_encoder = Wav2VecEncoder(cfg, None) | |
| pooling_layer = get_pooling_layer( | |
| cfg, | |
| w2v_encoder.w2v_model.encoder.layers[-1].embedding_dim, | |
| len(task.target_dictionary), | |
| len(w2v_encoder.w2v_model.encoder.layers), | |
| ) | |
| return cls(cfg, w2v_encoder, pooling_layer) | |
| def get_normalized_probs(self, net_output, log_probs): | |
| """Get normalized probabilities (or log probs) from a net's output.""" | |
| logits = net_output | |
| if log_probs: | |
| return utils.log_softmax(logits.float(), dim=-1) | |
| else: | |
| return utils.softmax(logits.float(), dim=-1) | |
| def get_logits(self, net_output): | |
| return net_output | |
| def forward(self, **kwargs): | |
| encoder_out_dict = self.w2v_encoder(**kwargs) | |
| w2v_encoder_out = encoder_out_dict["encoder_out"] # TxBxC | |
| w2v_encoder_padding_mask = encoder_out_dict["padding_mask"] # BxT | |
| # w2v_encoder_layer_results = encoder_out_dict["layer_results"] | |
| return self.pooling_layer( | |
| last_layer_feats=w2v_encoder_out, | |
| padding_mask=w2v_encoder_padding_mask, | |
| # all_layer_feats=w2v_encoder_layer_results, | |
| ) | |
| # def forward_latent(self, **kwargs): | |
| # encoder_out_dict = self.w2v_encoder(**kwargs) | |
| # w2v_encoder_out = encoder_out_dict["encoder_out"] | |
| # w2v_encoder_padding_mask = encoder_out_dict["encoder_padding_mask"] | |
| # w2v_encoder_layer_results = encoder_out_dict["layer_results"] | |
| # return self.pooling_layer.forward_latent( | |
| # last_layer_feats=w2v_encoder_out, | |
| # padding_mask=w2v_encoder_padding_mask, | |
| # all_layer_feats=w2v_encoder_layer_results, | |
| # ) | |
| def get_pooling_layer( | |
| cfg: Wav2Vec2ClassificationConfig, | |
| encoder_embed_dim: int, | |
| num_targets: int, | |
| encoder_layers: int, | |
| ): | |
| assert cfg.pooling == 'mean' | |
| if cfg.pooling == "first_token": | |
| return FirstToken(cfg, encoder_embed_dim, num_targets) | |
| # elif cfg.pooling == "mean": | |
| # return MeanPooling(cfg, encoder_embed_dim, num_targets) | |
| elif cfg.pooling == "mean": | |
| return MeanPoolingFast(cfg, encoder_embed_dim, num_targets) | |
| elif cfg.pooling == "mean_amsoftmax": | |
| return MeanPoolingFastAMSoftmax(cfg, encoder_embed_dim, num_targets) | |
| elif cfg.pooling == "max": | |
| return MaxPoolingFast(cfg, encoder_embed_dim, num_targets) | |
| elif cfg.pooling == "elmo": | |
| return LayerWeightedMeanPooling( | |
| cfg, encoder_embed_dim, num_targets, encoder_layers | |
| ) | |
| else: | |
| raise NotImplementedError(f"{cfg.pooling} has not been implemented yet.") | |
| class Pooling(nn.Module): | |
| def __init__( | |
| self, | |
| cfg: Wav2Vec2ClassificationConfig, | |
| encoder_embed_dim: int, | |
| num_targets: int, | |
| ): | |
| super().__init__() | |
| self.projection = Linear(encoder_embed_dim, num_targets) | |
| def forward(self, last_layer_feats, **kwargs): | |
| raise NotImplementedError() | |
| class FirstToken(Pooling): | |
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| def forward(self, last_layer_feats, **kwargs): | |
| return self.projection(last_layer_feats[:, 0]) | |
| # class MeanPooling(Pooling): | |
| # def __init__( | |
| # self, | |
| # cfg: Wav2VecClassificationConfig, | |
| # encoder_embed_dim: int, | |
| # num_targets: int, | |
| # **kwargs, | |
| # ): | |
| # super().__init__(cfg, encoder_embed_dim, num_targets) | |
| # self.activation_fn = utils.get_activation_fn(cfg.activation_fn) | |
| # self.linear = Linear(encoder_embed_dim, encoder_embed_dim) | |
| # def forward(self, last_layer_feats, padding_mask, **kwargs): | |
| # # last_layer_feats: [BxTxD] | |
| # # padding_mask: [BxT] | |
| # last_layer_feats = self.linear(self.activation_fn(last_layer_feats)) | |
| # input_lengths = (1 - padding_mask.long()).sum(-1) | |
| # pooled_feature_list = [] | |
| # for i in range(len(last_layer_feats)): | |
| # length = input_lengths[i] | |
| # pooled_feature = torch.mean(last_layer_feats[i][:length], dim=0) | |
| # pooled_feature_list.append(pooled_feature) | |
| # return self.projection(torch.stack(pooled_feature_list)) | |
| def fn_mean(x, mask): | |
| """ | |
| Args: | |
| x: TxBxD | |
| mask: BxT | |
| Return: | |
| y: BxD | |
| """ | |
| if mask is not None: | |
| mask = mask.t()[:, :, None] | |
| return (x * mask).sum(0) / mask.sum(0) | |
| else: | |
| return x.sum(0) / x.shape[0] | |
| class MeanPoolingFast(nn.Module): | |
| def __init__( | |
| self, | |
| cfg: Wav2Vec2ClassificationConfig, | |
| encoder_embed_dim: int, | |
| num_targets: int, | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| self.activation_fn = utils.get_activation_fn(cfg.activation_fn) | |
| self.latent_embed_dim = ( | |
| cfg.latent_embed_dim | |
| if cfg.latent_embed_dim is not None | |
| else encoder_embed_dim | |
| ) | |
| logging.debug(f"| {self.latent_embed_dim=}") | |
| self.linear = Linear(encoder_embed_dim, self.latent_embed_dim) | |
| self.projection = Linear(self.latent_embed_dim, num_targets) | |
| def forward(self, last_layer_feats, padding_mask, **kwargs): | |
| """ | |
| Arguments | |
| features - [TxBxD] Acoustic feature with shape | |
| padding_mask - [BxT] Padding Mask | |
| """ | |
| if padding_mask is not None: | |
| feat_mask = (~padding_mask).to(last_layer_feats.dtype) | |
| else: | |
| feat_mask = None | |
| feat = self.linear(last_layer_feats) | |
| feat = fn_mean(feat, feat_mask) | |
| feat = self.activation_fn(feat) | |
| return self.projection(feat) | |
| def forward_latent(self, last_layer_feats, padding_mask, **kwargs): | |
| """ | |
| Arguments | |
| features - [TxBxD] Acoustic feature with shape | |
| padding_mask - [BxT] Padding Mask | |
| """ | |
| if padding_mask is not None: | |
| feat_mask = (~padding_mask).to(last_layer_feats.dtype) | |
| else: | |
| feat_mask = None | |
| feat = self.linear(last_layer_feats) | |
| feat = fn_mean(feat, feat_mask) | |
| return feat | |
| class MeanPoolingFastAMSoftmax(MeanPoolingFast): | |
| def __init__( | |
| self, | |
| cfg: Wav2Vec2ClassificationConfig, | |
| encoder_embed_dim: int, | |
| num_targets: int, | |
| **kwargs, | |
| ): | |
| super().__init__(cfg, encoder_embed_dim, num_targets, **kwargs) | |
| self.projection = Linear(self.latent_embed_dim, num_targets, bias=False) | |
| nn.init.xavier_normal_(self.projection.weight, gain=1) | |
| def forward(self, last_layer_feats, padding_mask, **kwargs): | |
| """ | |
| Arguments | |
| features - [BxTxD] Acoustic feature with shape | |
| padding_mask - [BxT] Padding Mask | |
| """ | |
| feat_mask = (~padding_mask).to(last_layer_feats.dtype) # T,B -> B,T | |
| feat = self.linear(last_layer_feats) # B,T,D | |
| feat = fn_mean(feat, feat_mask) # B,D | |
| feat = self.activation_fn(feat) | |
| # normalize feat | |
| feat_norm = F.normalize(feat, p=2, dim=-1) # B,D | |
| weight_norm = F.normalize(self.projection.weight.t(), p=2, dim=-1) # D,K | |
| cos_fw = feat_norm @ weight_norm | |
| return cos_fw | |
| def fn_max(x, mask): | |
| """ | |
| Args: | |
| x: TxBxD | |
| mask: BxT | |
| Return: | |
| y: BxD | |
| """ | |
| mask = mask.t()[:, :, None].to(torch.bool) | |
| return x.masked_fill(~mask, -1e-8).max(0)[0] | |
| class MaxPoolingFast(Pooling): | |
| def __init__( | |
| self, | |
| cfg: Wav2Vec2ClassificationConfig, | |
| encoder_embed_dim: int, | |
| num_targets: int, | |
| **kwargs, | |
| ): | |
| super().__init__(cfg, encoder_embed_dim, num_targets) | |
| self.activation_fn = utils.get_activation_fn(cfg.activation_fn) | |
| self.linear = Linear(encoder_embed_dim, encoder_embed_dim) | |
| def forward(self, last_layer_feats, padding_mask, **kwargs): | |
| """ | |
| Arguments | |
| features - [TxBxD] Acoustic feature with shape | |
| padding_mask - [BxT] Padding Mask | |
| """ | |
| feat_mask = (~padding_mask).to(last_layer_feats.dtype) | |
| feat = self.linear(last_layer_feats) | |
| feat = fn_max(feat, feat_mask) | |
| feat = self.activation_fn(feat) | |
| return self.projection(feat) | |
| class LayerWeightedMeanPooling(MeanPoolingFast): | |
| """Elmo-style weighted average representation.""" | |
| def __init__( | |
| self, | |
| cfg: Wav2Vec2ClassificationConfig, | |
| encoder_embed_dim: int, | |
| num_targets: int, | |
| encoder_layers: int, | |
| ): | |
| super().__init__(cfg, encoder_embed_dim, num_targets) | |
| self.num_layers = encoder_layers | |
| self.weights = nn.Parameter(torch.ones(encoder_layers)) | |
| def forward(self, last_layer_feats, padding_mask, all_layer_feats): | |
| # last_layer_feats: [BxTxD] | |
| # padding_mask: [BxT] | |
| if not self.training: | |
| msg = ( | |
| f"Number of layers in input features = {len(all_layer_feats)}." | |
| f" Expected {self.num_layers} layers." | |
| ) | |
| assert len(all_layer_feats) == self.num_layers, msg | |
| # Stack up all layers and reshape to (num_layers, features) | |
| all_layer_feats_stacked = torch.stack(all_layer_feats, dim=0) | |
| num_layers, *original_feat_shape = all_layer_feats_stacked.shape | |
| all_layer_feats_stacked_flat = all_layer_feats_stacked.view(num_layers, -1) | |
| # Weighted average | |
| normalized_weights = F.softmax(self.weights, dim=-1) | |
| weighted_avg_features = ( | |
| normalized_weights.unsqueeze(-1) * all_layer_feats_stacked_flat | |
| ).sum(dim=0) | |
| weighted_avg_features = weighted_avg_features.view(*original_feat_shape) | |
| # Mean Pooling on weighted average features. | |
| return super().forward(weighted_avg_features, padding_mask) |