#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple, Union

import torch
from torch import Tensor, nn

from zipvoice.models.modules.scaling import FloatLike, ScheduledFloat, SwooshR
from zipvoice.models.modules.zipformer import (
    DownsampledZipformer2Encoder,
    TTSZipformer,
    Zipformer2Encoder,
    Zipformer2EncoderLayer,
)


def timestep_embedding(timesteps, dim, max_period=10000):
    """Create sinusoidal timestep embeddings.

    :param timesteps: a Tensor of shape (N,) or (N, T).
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: a Tensor of positional embeddings, of shape (N, dim) or (T, N, dim).
    """
    half = dim // 2
    # Geometric progression of frequencies, from 1 down to 1/max_period.
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
        / half
    )
    if timesteps.dim() == 2:
        timesteps = timesteps.transpose(0, 1)  # (N, T) -> (T, N)
    args = timesteps[..., None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad with a zero column so the output dimension is exactly `dim`.
        embedding = torch.cat(
            [embedding, torch.zeros_like(embedding[..., :1])], dim=-1
        )
    return embedding
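
# Illustrative example (not from the original file): for dim=8, a (3,)-shaped
# input yields a (3, 8) embedding, and a (3, 5)-shaped input yields (5, 3, 8):
#
#     t = torch.tensor([0.0, 0.5, 1.0])     # shape (N,) = (3,)
#     emb = timestep_embedding(t, dim=8)    # emb.shape == (3, 8)
#     t2 = torch.rand(3, 5)                 # shape (N, T) = (3, 5)
#     emb2 = timestep_embedding(t2, dim=8)  # emb2.shape == (5, 3, 8)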


class TTSZipformerTwoStream(TTSZipformer):
    """
    A TTSZipformer variant with two input/output streams: two input projections
    and two output projections share the same encoder stacks, and each forward
    pass is routed through the stream whose in_dim matches the input's feature
    dimension.

    Note: all "int or Tuple[int]" arguments below will be treated as tuples of
    the same length as downsampling_factor if they are single ints or
    one-element tuples. The length of downsampling_factor defines the number
    of stacks.

    Args:
        in_dim (Tuple[int]): input feature dimension of each of the two
            streams.
        out_dim (Tuple[int]): output feature dimension of each of the two
            streams.
        downsampling_factor (Tuple[int]): downsampling factor for each encoder
            stack.
        encoder_dim (int): embedding dimension used by all encoder stacks.
        num_encoder_layers (int or Tuple[int]): number of encoder layers for
            each stack.
        query_head_dim (int or Tuple[int]): dimension of query and key per
            attention head: per stack, if a tuple.
        pos_head_dim (int or Tuple[int]): dimension of positional-encoding
            projection per attention head.
        value_head_dim (int or Tuple[int]): dimension of value in each
            attention head.
        num_heads (int or Tuple[int]): number of heads in the self-attention
            mechanism. Must be at least 4.
        feedforward_dim (int or Tuple[int]): hidden dimension in feedforward
            modules.
        cnn_module_kernel (int or Tuple[int]): kernel size of convolution
            module.
        pos_dim (int): the dimension of each positional-encoding vector prior
            to projection, e.g. 128.
        dropout (float): dropout rate.
        warmup_batches (float): number of batches to warm up over; this
            controls dropout of encoder layers.
        use_time_embed (bool): if True, take a time embedding as an additional
            input.
        time_embed_dim (int): the dimension of the time embedding.
        use_conv (bool): if True, use a convolution module in each encoder
            layer.
    """

    def __init__(
        self,
        in_dim: Tuple[int],
        out_dim: Tuple[int],
        downsampling_factor: Tuple[int] = (2, 4),
        num_encoder_layers: Union[int, Tuple[int]] = 4,
        cnn_module_kernel: Union[int, Tuple[int]] = 31,
        encoder_dim: int = 384,
        query_head_dim: int = 24,
        pos_head_dim: int = 4,
        value_head_dim: int = 12,
        num_heads: int = 8,
        feedforward_dim: int = 1536,
        pos_dim: int = 192,
        dropout: FloatLike = None,  # see code below for default
        warmup_batches: float = 4000.0,
        use_time_embed: bool = True,
        time_embed_dim: int = 192,
        use_conv: bool = True,
    ) -> None:
        nn.Module.__init__(self)

        if dropout is None:
            dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))

        if isinstance(downsampling_factor, int):
            downsampling_factor = (downsampling_factor,)

        def _to_tuple(x):
            """Converts a single int or a 1-tuple of an int to a tuple with
            the same length as downsampling_factor"""
            if isinstance(x, int):
                x = (x,)
            if len(x) == 1:
                x = x * len(downsampling_factor)
            else:
                assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
            return x
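
        # For example, with downsampling_factor == (2, 4):
        #     _to_tuple(4)      -> (4, 4)
        #     _to_tuple((4,))   -> (4, 4)
        #     _to_tuple((4, 6)) -> (4, 6)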

        def _assert_downsampling_factor(factors):
            """Assert that downsampling_factor follows a u-net style:
            doubling up to the middle stack, then halving back down to 1."""
            assert factors[0] == 1 and factors[-1] == 1
            for i in range(1, len(factors) // 2 + 1):
                assert factors[i] == factors[i - 1] * 2
            for i in range(len(factors) // 2 + 1, len(factors)):
                assert factors[i] * 2 == factors[i - 1]
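
        # For example, (1, 2, 4, 2, 1) satisfies this check: the factor
        # doubles at each stack up to the middle one, then halves back to 1.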
        _assert_downsampling_factor(downsampling_factor)
        self.downsampling_factor = downsampling_factor  # tuple
        num_encoder_layers = _to_tuple(num_encoder_layers)
        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)

        self.encoder_dim = encoder_dim
        self.num_encoder_layers = num_encoder_layers
        self.query_head_dim = query_head_dim
        self.value_head_dim = value_head_dim
        self.num_heads = num_heads

        self.use_time_embed = use_time_embed
        self.time_embed_dim = time_embed_dim
        if self.use_time_embed:
            assert time_embed_dim != -1
        else:
            time_embed_dim = -1

        # One input/output projection per stream; the stream is selected in
        # forward() by the feature dimension of the input.
        assert len(in_dim) == len(out_dim) == 2
        self.in_dim = in_dim
        self.in_proj = nn.ModuleList(
            [nn.Linear(in_dim[0], encoder_dim), nn.Linear(in_dim[1], encoder_dim)]
        )
        self.out_dim = out_dim
        self.out_proj = nn.ModuleList(
            [nn.Linear(encoder_dim, out_dim[0]), nn.Linear(encoder_dim, out_dim[1])]
        )

        # Each one will be Zipformer2Encoder or DownsampledZipformer2Encoder.
        encoders = []
        num_encoders = len(downsampling_factor)
        for i in range(num_encoders):
            encoder_layer = Zipformer2EncoderLayer(
                embed_dim=encoder_dim,
                pos_dim=pos_dim,
                num_heads=num_heads,
                query_head_dim=query_head_dim,
                pos_head_dim=pos_head_dim,
                value_head_dim=value_head_dim,
                feedforward_dim=feedforward_dim,
                use_conv=use_conv,
                cnn_module_kernel=cnn_module_kernel[i],
                dropout=dropout,
            )

            # For the first segment of the warmup period, we let the
            # Conv2dSubsampling layer learn something.  Then we start to
            # warm up the other encoders; the warmup windows of successive
            # stacks are staggered across the warmup period.
            encoder = Zipformer2Encoder(
                encoder_layer,
                num_encoder_layers[i],
                embed_dim=encoder_dim,
                time_embed_dim=time_embed_dim,
                pos_dim=pos_dim,
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
            )
            if downsampling_factor[i] != 1:
                encoder = DownsampledZipformer2Encoder(
                    encoder,
                    dim=encoder_dim,
                    downsample=downsampling_factor[i],
                )
            encoders.append(encoder)
        self.encoders = nn.ModuleList(encoders)

        if self.use_time_embed:
            # Small MLP applied to the sinusoidal timestep embedding before it
            # is fed to the encoder stacks.
            self.time_embed = nn.Sequential(
                nn.Linear(time_embed_dim, time_embed_dim * 2),
                SwooshR(),
                nn.Linear(time_embed_dim * 2, time_embed_dim),
            )
        else:
            self.time_embed = None

    def forward(
        self,
        x: Tensor,
        t: Optional[Tensor] = None,
        padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
          x:
            The input tensor, of shape (batch_size, seq_len, feature_dim).
            feature_dim must equal in_dim[0] or in_dim[1]; this selects which
            of the two streams is used.
          t:
            A timestep tensor of shape (batch_size,) or (batch_size, seq_len).
          padding_mask:
            The mask for padding, of shape (batch_size, seq_len); True means
            masked position. May be None.
        Returns:
          The output embeddings, of shape (batch_size, seq_len, out_dim[index]),
          where index is the selected stream.
        """
        # Select the stream whose input projection matches the feature dim.
        assert x.size(2) in self.in_dim, f"{x.size(2)} in {self.in_dim}"
        if x.size(2) == self.in_dim[0]:
            index = 0
        else:
            index = 1

        x = x.permute(1, 0, 2)  # (batch, seq_len, dim) -> (seq_len, batch, dim)
        x = self.in_proj[index](x)

        if t is not None:
            assert t.dim() == 1 or t.dim() == 2, t.shape
            time_emb = timestep_embedding(t, self.time_embed_dim)
            time_emb = self.time_embed(time_emb)
        else:
            time_emb = None

        attn_mask = None
        for module in self.encoders:
            x = module(
                x,
                time_emb=time_emb,
                src_key_padding_mask=padding_mask,
                attn_mask=attn_mask,
            )

        x = self.out_proj[index](x)
        x = x.permute(1, 0, 2)  # (seq_len, batch, dim) -> (batch, seq_len, dim)
        return x
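

# A minimal usage sketch (illustrative only; the dimensions below are made up,
# not taken from any ZipVoice recipe). The model routes each input through the
# stream whose in_dim matches the input's feature dimension:
#
#     model = TTSZipformerTwoStream(
#         in_dim=(100, 80),
#         out_dim=(100, 80),
#         downsampling_factor=(1, 2, 1),  # u-net style, as asserted above
#     )
#     x = torch.randn(2, 50, 80)  # feature dim 80 matches in_dim[1] -> stream 1
#     t = torch.rand(2)           # one timestep per utterance, shape (batch,)
#     y = model(x, t=t)           # y.shape == (2, 50, 80)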