STAR / fairseq /models /wav2vec /wav2vec2_classification.py
Yixuan Li
add fairseq folder
85ba398
raw
history blame
12.1 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import II, MISSING, open_dict
from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES, Wav2Vec2Config
from fairseq.models.wav2vec.wav2vec2_asr import Embedding, Linear, Wav2VecEncoder, Wav2Vec2AsrConfig
from fairseq.tasks import FairseqTask
# NOTE(review): import-time side effect — requests DEBUG-level configuration of
# the root logger whenever this module is imported (per the stdlib docs,
# basicConfig does nothing if the root logger already has handlers).
logging.basicConfig(level=logging.DEBUG)
@dataclass
class Wav2Vec2ClassificationConfig(Wav2Vec2AsrConfig):
    """Configuration for the wav2vec 2.0 classification model.

    Extends ``Wav2Vec2AsrConfig`` with the options read by the pooling heads
    defined in this file.
    """

    # Width of the hidden layer placed between the encoder output and the
    # class projection. ``None`` makes MeanPoolingFast fall back to the
    # encoder embedding width.
    latent_embed_dim: Optional[int] = field(
        default=None,
        metadata={"help": "latent dim (encoder w2v -> latent -> class)"},
    )
    # Name of the pooling head; dispatched on by get_pooling_layer.
    pooling: str = field(
        default="first_token",
        metadata={
            "help": "pooling layer choices: "
            "first_token, mean, mean_amsoftmax, max, elmo"
        },
    )
    # Activation applied to the pooled features by the mean/max heads.
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="gelu", metadata={"help": "activation function to use"}
    )
@register_model("wav2vec_classification", dataclass=Wav2Vec2ClassificationConfig)
class Wav2VecClassification(BaseFairseqModel):
    """wav2vec 2.0 encoder followed by a pooling/classification head.

    The encoder is called with whatever keyword arguments ``forward`` receives;
    its ``"encoder_out"`` and ``"padding_mask"`` entries are handed to the
    pooling layer, whose output is returned as the logits.
    """

    # TODO: Can be shared/merged with ASR model class as w2v_encoder params are common.
    def __init__(
        self,
        cfg: Wav2Vec2ClassificationConfig,
        w2v_encoder: BaseFairseqModel,
        pooling_layer,
    ):
        """Store the config, the encoder and the pooling head."""
        super().__init__()
        self.cfg = cfg
        self.w2v_encoder = w2v_encoder
        self.pooling_layer = pooling_layer

    def upgrade_state_dict_named(self, state_dict, name):
        """Run the parent upgrade hook and return the same ``state_dict``."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2ClassificationConfig, task: FairseqTask):
        """Build a new model instance."""
        encoder = Wav2VecEncoder(cfg, None)
        # Head dimensions come from the encoder's last layer, the size of the
        # task's target dictionary and the number of encoder layers.
        head = get_pooling_layer(
            cfg,
            encoder.w2v_model.encoder.layers[-1].embedding_dim,
            len(task.target_dictionary),
            len(encoder.w2v_model.encoder.layers),
        )
        return cls(cfg, encoder, head)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        # The net output is the logits tensor itself (see get_logits).
        normalize = utils.log_softmax if log_probs else utils.softmax
        return normalize(net_output.float(), dim=-1)

    def get_logits(self, net_output):
        """Return the logits, which are the network output unchanged."""
        return net_output

    def forward(self, **kwargs):
        """Encode the input and pool the encoder output into logits."""
        encoded = self.w2v_encoder(**kwargs)
        return self.pooling_layer(
            last_layer_feats=encoded["encoder_out"],  # TxBxC
            padding_mask=encoded["padding_mask"],  # BxT
            # all_layer_feats=encoded["layer_results"],
        )

    # Disabled sketch of a latent-feature entry point, kept for reference:
    # def forward_latent(self, **kwargs):
    #     encoder_out_dict = self.w2v_encoder(**kwargs)
    #     w2v_encoder_out = encoder_out_dict["encoder_out"]
    #     w2v_encoder_padding_mask = encoder_out_dict["encoder_padding_mask"]
    #     w2v_encoder_layer_results = encoder_out_dict["layer_results"]
    #     return self.pooling_layer.forward_latent(
    #         last_layer_feats=w2v_encoder_out,
    #         padding_mask=w2v_encoder_padding_mask,
    #         all_layer_feats=w2v_encoder_layer_results,
    #     )
def get_pooling_layer(
    cfg: Wav2Vec2ClassificationConfig,
    encoder_embed_dim: int,
    num_targets: int,
    encoder_layers: int,
):
    """Instantiate the pooling head selected by ``cfg.pooling``.

    Args:
        cfg: model configuration; ``cfg.pooling`` picks the head.
        encoder_embed_dim: feature width of the encoder output.
        num_targets: number of output classes.
        encoder_layers: number of encoder layers (only used by ``"elmo"``).

    Returns:
        The constructed pooling module.

    Raises:
        NotImplementedError: if ``cfg.pooling`` names no known head.

    A previous hard ``assert cfg.pooling == 'mean'`` made every branch except
    ``"mean"`` unreachable and rejected the config default (``"first_token"``);
    it has been removed so the configured choice is honoured.

    NOTE(review): the ``"elmo"`` head's ``forward`` requires ``all_layer_feats``,
    which ``Wav2VecClassification.forward`` does not currently pass.
    """
    pooling = cfg.pooling
    if pooling == "first_token":
        return FirstToken(cfg, encoder_embed_dim, num_targets)
    if pooling == "mean":
        # The vectorised implementation stands in for the loop-based
        # MeanPooling that is commented out further down in this file.
        return MeanPoolingFast(cfg, encoder_embed_dim, num_targets)
    if pooling == "mean_amsoftmax":
        return MeanPoolingFastAMSoftmax(cfg, encoder_embed_dim, num_targets)
    if pooling == "max":
        return MaxPoolingFast(cfg, encoder_embed_dim, num_targets)
    if pooling == "elmo":
        return LayerWeightedMeanPooling(
            cfg, encoder_embed_dim, num_targets, encoder_layers
        )
    raise NotImplementedError(f"{cfg.pooling} has not been implemented yet.")
class Pooling(nn.Module):
    """Base class for pooling heads; owns the final class projection.

    Subclasses implement ``forward`` and map the pooled features through
    ``self.projection``.
    """

    def __init__(
        self,
        cfg: Wav2Vec2ClassificationConfig,
        encoder_embed_dim: int,
        num_targets: int,
    ):
        """Create the ``encoder_embed_dim -> num_targets`` projection.

        ``cfg`` is accepted so all heads share a constructor shape; it is not
        read here.
        """
        super().__init__()
        self.projection = Linear(encoder_embed_dim, num_targets)

    def forward(self, last_layer_feats, **kwargs):
        """Pool ``last_layer_feats`` into logits; must be overridden."""
        raise NotImplementedError
class FirstToken(Pooling):
    """Classify from the first time step of the encoder output.

    The constructor is inherited unchanged from ``Pooling``.
    """

    def forward(self, last_layer_feats, **kwargs):
        """Project the first frame of every sequence to class logits.

        Args:
            last_layer_feats: encoder output, time-major (TxBxC) as produced
                by ``Wav2VecClassification.forward``.
            **kwargs: ignored (e.g. ``padding_mask``).

        Returns:
            Logits, one row per batch element.
        """
        # Index the leading (time) axis. The previous ``[:, 0]`` picked batch
        # element 0 at every time step, yielding T rows instead of B.
        return self.projection(last_layer_feats[0])
# class MeanPooling(Pooling):
# def __init__(
# self,
# cfg: Wav2VecClassificationConfig,
# encoder_embed_dim: int,
# num_targets: int,
# **kwargs,
# ):
# super().__init__(cfg, encoder_embed_dim, num_targets)
# self.activation_fn = utils.get_activation_fn(cfg.activation_fn)
# self.linear = Linear(encoder_embed_dim, encoder_embed_dim)
# def forward(self, last_layer_feats, padding_mask, **kwargs):
# # last_layer_feats: [BxTxD]
# # padding_mask: [BxT]
# last_layer_feats = self.linear(self.activation_fn(last_layer_feats))
# input_lengths = (1 - padding_mask.long()).sum(-1)
# pooled_feature_list = []
# for i in range(len(last_layer_feats)):
# length = input_lengths[i]
# pooled_feature = torch.mean(last_layer_feats[i][:length], dim=0)
# pooled_feature_list.append(pooled_feature)
# return self.projection(torch.stack(pooled_feature_list))
def fn_mean(x, mask):
    """Average ``x`` over its leading (time) axis, optionally masked.

    Args:
        x: TxBxD features.
        mask: BxT weights (1 for valid frames, 0 for padding), or ``None`` to
            average over every frame.

    Return:
        y: BxD mean features. A sequence whose mask is all zero yields a zero
        vector instead of NaN.
    """
    if mask is None:
        return x.sum(0) / x.shape[0]
    # BxT -> TxBx1 so the mask broadcasts over the feature axis.
    weights = mask.t()[:, :, None]
    denom = weights.sum(0)
    # Guard fully masked sequences: 0/0 would produce NaN and poison the loss.
    denom = denom.masked_fill(denom == 0, 1)
    return (x * weights).sum(0) / denom
class MeanPoolingFast(nn.Module):
    """Masked mean pooling head: linear -> mean over time -> activation -> projection."""

    def __init__(
        self,
        cfg: Wav2Vec2ClassificationConfig,
        encoder_embed_dim: int,
        num_targets: int,
        **kwargs,
    ):
        """Build the latent layer and the class projection.

        The latent width is ``cfg.latent_embed_dim`` when set, otherwise
        ``encoder_embed_dim``.
        """
        super().__init__()
        self.activation_fn = utils.get_activation_fn(cfg.activation_fn)
        latent_dim = cfg.latent_embed_dim
        if latent_dim is None:
            latent_dim = encoder_embed_dim
        self.latent_embed_dim = latent_dim
        logging.debug(f"| {self.latent_embed_dim=}")
        self.linear = Linear(encoder_embed_dim, self.latent_embed_dim)
        self.projection = Linear(self.latent_embed_dim, num_targets)

    def _pooled_latent(self, last_layer_feats, padding_mask):
        """Apply the latent layer, then the (masked) mean over time."""
        if padding_mask is None:
            feat_mask = None
        else:
            # Padding mask is True on padded frames; invert it into weights.
            feat_mask = (~padding_mask).to(last_layer_feats.dtype)
        return fn_mean(self.linear(last_layer_feats), feat_mask)

    def forward(self, last_layer_feats, padding_mask, **kwargs):
        """
        Arguments
            features - [TxBxD] Acoustic feature with shape
            padding_mask - [BxT] Padding Mask (may be None)
        Returns the class logits.
        """
        pooled = self._pooled_latent(last_layer_feats, padding_mask)
        return self.projection(self.activation_fn(pooled))

    def forward_latent(self, last_layer_feats, padding_mask, **kwargs):
        """
        Arguments
            features - [TxBxD] Acoustic feature with shape
            padding_mask - [BxT] Padding Mask (may be None)
        Returns the pooled latent features (before activation and projection).
        """
        return self._pooled_latent(last_layer_feats, padding_mask)
class MeanPoolingFastAMSoftmax(MeanPoolingFast):
    """Mean pooling head that outputs cosine similarities to class vectors.

    Uses a bias-free projection whose weight rows act as class vectors; the
    output is the cosine between the pooled, activated feature and each one.
    """

    def __init__(
        self,
        cfg: Wav2Vec2ClassificationConfig,
        encoder_embed_dim: int,
        num_targets: int,
        **kwargs,
    ):
        """Build the parent layers, then replace the projection (no bias)."""
        super().__init__(cfg, encoder_embed_dim, num_targets, **kwargs)
        self.projection = Linear(self.latent_embed_dim, num_targets, bias=False)
        nn.init.xavier_normal_(self.projection.weight, gain=1)

    def forward(self, last_layer_feats, padding_mask, **kwargs):
        """
        Arguments
            features - [TxBxD] Acoustic feature with shape
            padding_mask - [BxT] Padding Mask (may be None)
        Returns
            [BxK] cosine similarity between each pooled feature and each
            class weight vector.
        """
        # Reuse the parent's latent path (linear + masked mean); it also
        # handles padding_mask=None, which the old inline `~padding_mask` did not.
        feat = self.forward_latent(last_layer_feats, padding_mask)  # B,D
        feat = self.activation_fn(feat)
        feat_norm = F.normalize(feat, p=2, dim=-1)  # B,D
        # projection.weight is K,D. Normalise each class vector over D *before*
        # transposing; normalising the transposed D,K matrix along its last
        # axis would scale across classes and the product would not be a cosine.
        weight_norm = F.normalize(self.projection.weight, p=2, dim=-1).t()  # D,K
        return feat_norm @ weight_norm
def fn_max(x, mask):
    """Take the maximum of ``x`` over its leading (time) axis, ignoring padding.

    Args:
        x: TxBxD features.
        mask: BxT, non-zero/True for valid frames; ``None`` means every frame
            is valid.

    Return:
        y: BxD element-wise maximum over the valid frames.
    """
    if mask is None:
        return x.max(0)[0]
    # BxT -> TxBx1 boolean mask that broadcasts over the feature axis.
    valid = mask.t()[:, :, None].to(torch.bool)
    # Padded frames must never win the max. The previous fill value of -1e-8
    # is effectively zero and beat every negative valid feature.
    if x.is_floating_point():
        lowest = torch.finfo(x.dtype).min
    else:
        lowest = torch.iinfo(x.dtype).min
    return x.masked_fill(~valid, lowest).max(0)[0]
class MaxPoolingFast(Pooling):
    """Max pooling head: linear -> max over time -> activation -> projection."""

    def __init__(
        self,
        cfg: Wav2Vec2ClassificationConfig,
        encoder_embed_dim: int,
        num_targets: int,
        **kwargs,
    ):
        """Build the class projection (via ``Pooling``) and a square latent layer."""
        super().__init__(cfg, encoder_embed_dim, num_targets)
        self.activation_fn = utils.get_activation_fn(cfg.activation_fn)
        self.linear = Linear(encoder_embed_dim, encoder_embed_dim)

    def forward(self, last_layer_feats, padding_mask, **kwargs):
        """
        Arguments
            features - [TxBxD] Acoustic feature with shape
            padding_mask - [BxT] Padding Mask (may be None)
        Returns the class logits.
        """
        feat = self.linear(last_layer_feats)
        if padding_mask is None:
            # No padding information: every frame is valid, so take a plain
            # max over time (the old code crashed on `~None` here).
            feat = feat.max(0)[0]
        else:
            feat_mask = (~padding_mask).to(last_layer_feats.dtype)
            feat = fn_max(feat, feat_mask)
        feat = self.activation_fn(feat)
        return self.projection(feat)
class LayerWeightedMeanPooling(MeanPoolingFast):
    """Elmo-style weighted average representation."""

    def __init__(
        self,
        cfg: Wav2Vec2ClassificationConfig,
        encoder_embed_dim: int,
        num_targets: int,
        encoder_layers: int,
    ):
        """Build the mean-pooling head plus one learnable weight per layer."""
        super().__init__(cfg, encoder_embed_dim, num_targets)
        self.num_layers = encoder_layers
        # Uniform start: softmax of equal values weights all layers equally.
        self.weights = nn.Parameter(torch.ones(encoder_layers))

    def forward(self, last_layer_feats, padding_mask, all_layer_feats):
        """Blend the per-layer features with softmax weights, then mean-pool.

        ``last_layer_feats`` is accepted for signature compatibility but not
        used; the blend of ``all_layer_feats`` is pooled instead.
        """
        # The layer-count check is only enforced outside training mode.
        if not self.training:
            msg = (
                f"Number of layers in input features = {len(all_layer_feats)}."
                f" Expected {self.num_layers} layers."
            )
            assert len(all_layer_feats) == self.num_layers, msg

        # (num_layers, *feat_shape) -> (num_layers, -1) so a single weight
        # per layer broadcasts across all feature entries.
        stacked = torch.stack(all_layer_feats, dim=0)
        layer_count = stacked.shape[0]
        feat_shape = stacked.shape[1:]
        flat = stacked.view(layer_count, -1)

        # Softmax keeps the layer weights positive and summing to one.
        mix = F.softmax(self.weights, dim=-1).unsqueeze(-1)
        blended = (mix * flat).sum(dim=0).view(*feat_shape)

        # Mean pooling on the blended features.
        return super().forward(blended, padding_mask)