# STAR/models/content_adapter.py
import math
import torch
import torch.nn as nn
from utils.torch_utilities import concat_non_padding, restore_from_concat
######################
# fastspeech modules
######################
class LayerNorm(nn.LayerNorm):
"""Layer normalization module.
:param int nout: output dim size
:param int dim: dimension to be normalized
"""
def __init__(self, nout, dim=-1):
"""Construct an LayerNorm object."""
super(LayerNorm, self).__init__(nout, eps=1e-12)
self.dim = dim
def forward(self, x):
"""Apply layer normalization.
:param torch.Tensor x: input tensor
:return: layer normalized tensor
:rtype torch.Tensor
"""
if self.dim == -1:
return super(LayerNorm, self).forward(x)
        return super(LayerNorm, self).forward(
            x.transpose(1, -1)
        ).transpose(1, -1)
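# FastSpeech-style duration predictor: a stack of
# (pad -> Conv1d -> ReLU -> channel-wise LayerNorm -> Dropout) blocks followed
# by a linear projection to one duration value per frame. padding="SAME" keeps
# the convolutions centred; any other value falls back to left (causal) padding.
# forward() takes x of shape [B, T, E] and a [B, T] mask and returns [B, T, 1].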
class DurationPredictor(nn.Module):
def __init__(
self,
in_channels: int,
filter_channels: int,
n_layers: int = 2,
kernel_size: int = 3,
p_dropout: float = 0.1,
padding: str = "SAME"
):
super(DurationPredictor, self).__init__()
self.conv = nn.ModuleList()
self.kernel_size = kernel_size
self.padding = padding
for idx in range(n_layers):
in_chans = in_channels if idx == 0 else filter_channels
            self.conv += [
                nn.Sequential(
                    nn.ConstantPad1d(
                        ((kernel_size - 1) // 2, (kernel_size - 1) // 2)
                        if padding == "SAME" else (kernel_size - 1, 0),
                        0
                    ),
                    nn.Conv1d(
                        in_chans,
                        filter_channels,
                        kernel_size,
                        stride=1,
                        padding=0
                    ),
                    nn.ReLU(),
                    LayerNorm(filter_channels, dim=1),
                    nn.Dropout(p_dropout)
                )
            ]
self.linear = nn.Linear(filter_channels, 1)
def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
# x: [B, T, E]
x = x.transpose(1, -1)
x_mask = x_mask.unsqueeze(1).to(x.device)
for f in self.conv:
x = f(x)
x = x * x_mask.float()
        # [B, T, 1]
        x = self.linear(x.transpose(1, -1)) * x_mask.transpose(1, -1).float()
        return x
######################
# adapter modules
######################
class ContentAdapterBase(nn.Module):
def __init__(self, d_out):
super().__init__()
self.d_out = d_out
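# Standard sinusoidal positional encoding ("Attention Is All You Need"),
# precomputed for up to max_len positions and added to batch-first inputs.
# Currently only referenced by the commented-out pos_embed in PrefixAdapter.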
class SinusoidalPositionalEmbedding(nn.Module):
def __init__(self, d_model, dropout, max_len=1000):
super().__init__()
self.dropout = nn.Dropout(dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
        # [1, max_len, d_model] so it broadcasts over batch-first [B, T, E]
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
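# Transformer content adapter: prepends a learnable CLS embedding to the input
# sequence, encodes it with a Transformer encoder, predicts durations with the
# attached DurationPredictor, and projects the encoded content to d_out.
# The CLS position yields a global duration; the remaining positions yield
# per-token (local) durations.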
class ContentAdapter(ContentAdapterBase):
def __init__(
self,
d_model: int,
d_out: int,
num_layers: int,
num_heads: int,
duration_predictor: DurationPredictor,
dropout: float = 0.1,
norm_first: bool = False,
activation: str = "gelu",
duration_grad_scale: float = 0.0,
):
super().__init__(d_out)
self.duration_grad_scale = duration_grad_scale
self.cls_embed = nn.Parameter(torch.randn(d_model))
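        # The nested-tensor fast path of nn.TransformerEncoder is disabled on
        # Ascend NPU devices, where it is presumably unsupported.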
if hasattr(torch, "npu") and torch.npu.is_available():
enable_nested_tensor = False
else:
enable_nested_tensor = True
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=num_heads,
dim_feedforward=4 * d_model,
dropout=dropout,
activation=activation,
norm_first=norm_first,
batch_first=True
)
self.encoder_layers = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=num_layers,
enable_nested_tensor=enable_nested_tensor
)
self.duration_predictor = duration_predictor
self.content_proj = nn.Conv1d(d_model, d_out, 1)
def forward(self, x, x_mask):
batch_size = x.size(0)
cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1)
cls_embed = cls_embed.to(x.device).unsqueeze(1)
x = torch.cat([cls_embed, x], dim=1)
cls_mask = torch.ones(batch_size, 1).to(x_mask.device)
x_mask = torch.cat([cls_mask, x_mask], dim=1)
x = self.encoder_layers(x, src_key_padding_mask=~x_mask.bool())
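        # Gradient-scaling trick: the forward value equals x, but the gradient
        # flowing from the duration loss back into the encoder is scaled by
        # duration_grad_scale (0.0 detaches it completely). Position 0 is the
        # CLS token, so duration[:, 0] is a global duration and duration[:, 1:]
        # are the per-token durations.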
        x_grad_rescaled = (
            x * self.duration_grad_scale
            + x.detach() * (1 - self.duration_grad_scale)
        )
duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1)
content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]
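# Prefix-conditioned content adapter: instruction embeddings are mapped through
# an MLP and packed together with the CLS + content sequence via
# concat_non_padding, so both are processed by one Transformer encoder;
# restore_from_concat then recovers the CLS + content part before duration
# prediction and output projection.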
class PrefixAdapter(ContentAdapterBase):
def __init__(
self,
content_dim: int,
d_model: int,
d_out: int,
prefix_dim: int,
num_layers: int,
num_heads: int,
duration_predictor: DurationPredictor,
dropout: float = 0.1,
norm_first: bool = False,
use_last_norm: bool = True,
activation: str = "gelu",
duration_grad_scale: float = 0.1,
):
super().__init__(d_out)
self.duration_grad_scale = duration_grad_scale
self.prefix_mlp = nn.Sequential(
nn.Linear(prefix_dim, d_model), nn.ReLU(), nn.Dropout(dropout),
nn.Linear(d_model, d_model)
)
self.content_mlp = nn.Sequential(
nn.Linear(content_dim, d_model), nn.ReLU(), nn.Dropout(dropout),
nn.Linear(d_model, d_model)
)
layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=num_heads,
dim_feedforward=4 * d_model,
dropout=dropout,
activation=activation,
batch_first=True,
norm_first=norm_first
)
if hasattr(torch, "npu") and torch.npu.is_available():
enable_nested_tensor = False
else:
enable_nested_tensor = True
self.cls_embed = nn.Parameter(torch.randn(d_model))
# self.pos_embed = SinusoidalPositionalEmbedding(d_model, dropout)
self.layers = nn.TransformerEncoder(
encoder_layer=layer,
num_layers=num_layers,
enable_nested_tensor=enable_nested_tensor
)
self.use_last_norm = use_last_norm
if self.use_last_norm:
self.last_norm = nn.LayerNorm(d_model)
self.duration_predictor = duration_predictor
self.content_proj = nn.Conv1d(d_model, d_out, 1)
nn.init.normal_(self.cls_embed, 0., 0.02)
nn.init.xavier_uniform_(self.content_proj.weight)
nn.init.constant_(self.content_proj.bias, 0.)
def forward(self, content, content_mask, instruction, instruction_mask):
batch_size = content.size(0)
cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1)
cls_embed = cls_embed.to(content.device).unsqueeze(1)
content = self.content_mlp(content)
x = torch.cat([cls_embed, content], dim=1)
        cls_mask = torch.ones(
            batch_size, 1, dtype=torch.bool, device=content_mask.device
        )
x_mask = torch.cat([cls_mask, content_mask], dim=1)
prefix = self.prefix_mlp(instruction)
seq, seq_mask, perm = concat_non_padding(
prefix, instruction_mask, x, x_mask
)
# seq = self.pos_embed(seq)
x = self.layers(seq, src_key_padding_mask=~seq_mask.bool())
if self.use_last_norm:
x = self.last_norm(x)
_, x = restore_from_concat(x, instruction_mask, x_mask, perm)
        x_grad_rescaled = (
            x * self.duration_grad_scale
            + x.detach() * (1 - self.duration_grad_scale)
        )
duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1)
content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]
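# Cross-attention adapter: the content sequence attends to the instruction
# (prefix) sequence with multi-head attention plus a residual connection.
# A masked mean over time feeds an MLP for the global duration, the
# DurationPredictor yields local durations, and a 1x1 convolution projects the
# content to d_out. No CLS token is used here.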
class CrossAttentionAdapter(ContentAdapterBase):
def __init__(
self,
d_out: int,
content_dim: int,
prefix_dim: int,
num_heads: int,
duration_predictor: DurationPredictor,
dropout: float = 0.1,
duration_grad_scale: float = 0.1,
):
super().__init__(d_out)
self.attn = nn.MultiheadAttention(
embed_dim=content_dim,
num_heads=num_heads,
dropout=dropout,
kdim=prefix_dim,
vdim=prefix_dim,
batch_first=True,
)
self.duration_grad_scale = duration_grad_scale
self.duration_predictor = duration_predictor
self.global_duration_mlp = nn.Sequential(
nn.Linear(content_dim, content_dim), nn.ReLU(),
nn.Dropout(dropout), nn.Linear(content_dim, 1)
)
self.norm = nn.LayerNorm(content_dim)
self.content_proj = nn.Conv1d(content_dim, d_out, 1)
def forward(self, content, content_mask, prefix, prefix_mask):
attn_output, attn_output_weights = self.attn(
query=content,
key=prefix,
value=prefix,
key_padding_mask=~prefix_mask.bool()
)
attn_output = attn_output * content_mask.unsqueeze(-1).float()
x = self.norm(attn_output + content)
        x_grad_rescaled = (
            x * self.duration_grad_scale
            + x.detach() * (1 - self.duration_grad_scale)
        )
        x_aggregated = (
            x_grad_rescaled * content_mask.unsqueeze(-1).float()
        ).sum(dim=1) / content_mask.sum(dim=1, keepdim=True).float()
global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
local_duration = self.duration_predictor(
x_grad_rescaled, content_mask
).squeeze(-1)
content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
return content, content_mask, global_duration, local_duration
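# Variant of CrossAttentionAdapter that adds MLP + LayerNorm preprocessing to
# both the content and prefix inputs, Xavier initialisation for all linear
# layers, and an MLP (instead of a 1x1 convolution) for the output projection.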
class ExperimentalCrossAttentionAdapter(ContentAdapterBase):
def __init__(
self,
d_out: int,
content_dim: int,
prefix_dim: int,
num_heads: int,
duration_predictor: DurationPredictor,
dropout: float = 0.1,
duration_grad_scale: float = 0.1,
):
super().__init__(d_out)
self.content_mlp = nn.Sequential(
nn.Linear(content_dim, content_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(content_dim, content_dim),
)
self.content_norm = nn.LayerNorm(content_dim)
self.prefix_mlp = nn.Sequential(
nn.Linear(prefix_dim, prefix_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(prefix_dim, prefix_dim),
)
        # Normalizes the prefix branch, whose feature size is prefix_dim.
        self.prefix_norm = nn.LayerNorm(prefix_dim)
self.attn = nn.MultiheadAttention(
embed_dim=content_dim,
num_heads=num_heads,
dropout=dropout,
kdim=prefix_dim,
vdim=prefix_dim,
batch_first=True,
)
self.duration_grad_scale = duration_grad_scale
self.duration_predictor = duration_predictor
self.global_duration_mlp = nn.Sequential(
nn.Linear(content_dim, content_dim), nn.ReLU(),
nn.Dropout(dropout), nn.Linear(content_dim, 1)
)
self.content_proj = nn.Sequential(
nn.Linear(content_dim, d_out),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_out, d_out),
)
self.norm1 = nn.LayerNorm(content_dim)
self.norm2 = nn.LayerNorm(d_out)
self.init_weights()
def init_weights(self):
def _init_weights(module):
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0.)
self.apply(_init_weights)
def forward(self, content, content_mask, prefix, prefix_mask):
content = self.content_mlp(content)
content = self.content_norm(content)
prefix = self.prefix_mlp(prefix)
prefix = self.prefix_norm(prefix)
attn_output, attn_weights = self.attn(
query=content,
key=prefix,
value=prefix,
key_padding_mask=~prefix_mask.bool(),
)
attn_output = attn_output * content_mask.unsqueeze(-1).float()
x = attn_output + content
x = self.norm1(x)
        x_grad_rescaled = (
            x * self.duration_grad_scale
            + x.detach() * (1 - self.duration_grad_scale)
        )
        x_aggregated = (
            x_grad_rescaled * content_mask.unsqueeze(-1).float()
        ).sum(dim=1) / content_mask.sum(dim=1, keepdim=True).float()
global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
local_duration = self.duration_predictor(
x_grad_rescaled, content_mask
).squeeze(-1)
content = self.content_proj(x)
content = self.norm2(content)
return content, content_mask, global_duration, local_duration
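

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): the
    # hyper-parameters below are arbitrary and only exercise the shapes of
    # DurationPredictor and CrossAttentionAdapter, which do not depend on the
    # concat_non_padding / restore_from_concat utilities.
    batch_size, content_len, prefix_len = 2, 50, 12
    content_dim, prefix_dim, d_out = 256, 128, 192

    duration_predictor = DurationPredictor(
        in_channels=content_dim, filter_channels=256
    )
    adapter = CrossAttentionAdapter(
        d_out=d_out,
        content_dim=content_dim,
        prefix_dim=prefix_dim,
        num_heads=4,
        duration_predictor=duration_predictor,
    )

    content = torch.randn(batch_size, content_len, content_dim)
    content_mask = torch.ones(batch_size, content_len, dtype=torch.bool)
    prefix = torch.randn(batch_size, prefix_len, prefix_dim)
    prefix_mask = torch.ones(batch_size, prefix_len, dtype=torch.bool)

    out, out_mask, global_dur, local_dur = adapter(
        content, content_mask, prefix, prefix_mask
    )
    # Expected shapes: [2, 50, 192], [2, 50], [2], [2, 50]
    print(out.shape, out_mask.shape, global_dur.shape, local_dur.shape)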