# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from zipvoice.models.modules.zipformer_two_stream import TTSZipformerTwoStream
from zipvoice.models.zipvoice import ZipVoice
from zipvoice.utils.common import condition_time_mask_suffix, make_pad_mask, pad_labels

class ZipVoiceDialog(ZipVoice):
    """The ZipVoice-Dialog model."""

    def __init__(
        self,
        fm_decoder_downsampling_factor: List[int] = [1, 2, 4, 2, 1],
        fm_decoder_num_layers: List[int] = [2, 2, 4, 4, 4],
        fm_decoder_cnn_module_kernel: List[int] = [31, 15, 7, 15, 31],
        fm_decoder_feedforward_dim: int = 1536,
        fm_decoder_num_heads: int = 4,
        fm_decoder_dim: int = 512,
        text_encoder_num_layers: int = 4,
        text_encoder_feedforward_dim: int = 512,
        text_encoder_cnn_module_kernel: int = 9,
        text_encoder_num_heads: int = 4,
        text_encoder_dim: int = 192,
        time_embed_dim: int = 192,
        text_embed_dim: int = 192,
        query_head_dim: int = 32,
        value_head_dim: int = 12,
        pos_head_dim: int = 4,
        pos_dim: int = 48,
        feat_dim: int = 100,
        vocab_size: int = 26,
        pad_id: int = 0,
        spk_a_id: int = 360,
        spk_b_id: int = 361,
    ):
        """
        Initialize the model with specified configuration parameters.

        Args:
            fm_decoder_downsampling_factor: List of downsampling factors for each layer
                in the flow-matching decoder.
            fm_decoder_num_layers: List of the number of layers for each block in the
                flow-matching decoder.
            fm_decoder_cnn_module_kernel: List of kernel sizes for CNN modules in the
                flow-matching decoder.
            fm_decoder_feedforward_dim: Dimension of the feedforward network in the
                flow-matching decoder.
            fm_decoder_num_heads: Number of attention heads in the flow-matching
                decoder.
            fm_decoder_dim: Hidden dimension of the flow-matching decoder.
            text_encoder_num_layers: Number of layers in the text encoder.
            text_encoder_feedforward_dim: Dimension of the feedforward network in the
                text encoder.
            text_encoder_cnn_module_kernel: Kernel size for the CNN module in the
                text encoder.
            text_encoder_num_heads: Number of attention heads in the text encoder.
            text_encoder_dim: Hidden dimension of the text encoder.
            time_embed_dim: Dimension of the time embedding.
            text_embed_dim: Dimension of the text embedding.
            query_head_dim: Dimension of the query attention head.
            value_head_dim: Dimension of the value attention head.
            pos_head_dim: Dimension of the position attention head.
            pos_dim: Dimension of the positional encoding.
            feat_dim: Dimension of the acoustic features.
            vocab_size: Size of the vocabulary.
            pad_id: ID used for padding tokens.
            spk_a_id: ID of speaker A / [S1].
            spk_b_id: ID of speaker B / [S2].
        """
        super().__init__(
            fm_decoder_downsampling_factor=fm_decoder_downsampling_factor,
            fm_decoder_num_layers=fm_decoder_num_layers,
            fm_decoder_cnn_module_kernel=fm_decoder_cnn_module_kernel,
            fm_decoder_feedforward_dim=fm_decoder_feedforward_dim,
            fm_decoder_num_heads=fm_decoder_num_heads,
            fm_decoder_dim=fm_decoder_dim,
            text_encoder_num_layers=text_encoder_num_layers,
            text_encoder_feedforward_dim=text_encoder_feedforward_dim,
            text_encoder_cnn_module_kernel=text_encoder_cnn_module_kernel,
            text_encoder_num_heads=text_encoder_num_heads,
            text_encoder_dim=text_encoder_dim,
            time_embed_dim=time_embed_dim,
            text_embed_dim=text_embed_dim,
            query_head_dim=query_head_dim,
            value_head_dim=value_head_dim,
            pos_head_dim=pos_head_dim,
            pos_dim=pos_dim,
            feat_dim=feat_dim,
            vocab_size=vocab_size,
            pad_id=pad_id,
        )
        self.spk_a_id = spk_a_id
        self.spk_b_id = spk_b_id
        self.spk_embed = nn.Embedding(2, feat_dim)
        torch.nn.init.normal_(self.spk_embed.weight, mean=0, std=0.1)

    def extract_spk_indices(self, tensor):
        turn_mask = ((tensor == self.spk_a_id) | (tensor == self.spk_b_id)).long()
        turn_counts = turn_mask.cumsum(dim=1)
        spk_mask = turn_counts % 2
        spk_mask = torch.where(tensor == self.pad_id, -1, spk_mask)
        spk_a_indices = torch.where(spk_mask == 0)
        spk_b_indices = torch.where(spk_mask == 1)
        return spk_a_indices, spk_b_indices
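
    # Worked illustration (not executed; hypothetical token ids): with
    # spk_a_id=360 ([S1]), spk_b_id=361 ([S2]) and pad_id=0, the padded row
    #     [360, 11, 12, 361, 13, 0]
    # gives turn_mask   = [1, 0, 0, 1, 0, 0]
    #       turn_counts = [1, 1, 1, 2, 2, 2]
    #       spk_mask    = [1, 1, 1, 0, 0, -1]
    # so positions alternate between the two speaker-embedding indices at each
    # [S1]/[S2] turn token, while padding positions (-1) fall in neither group.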
    def forward_text_embed(
        self,
        tokens: List[List[int]],
    ):
        """
        Get the text embeddings.

        Args:
            tokens: a list of lists of token ids.
        Returns:
            embed: the text embeddings, shape (batch, seq_len, emb_dim).
            tokens_lens: the length of each token sequence, shape (batch,).
        """
        device = (
            self.device if isinstance(self, DDP) else next(self.parameters()).device
        )
        tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device)  # (B, S)
        embed = self.embed(tokens_padded)  # (B, S, C)
        spk_a_indices, spk_b_indices = self.extract_spk_indices(tokens_padded)
        tokens_lens = torch.tensor(
            [len(token) for token in tokens], dtype=torch.int64, device=device
        )
        tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1])  # (B, S)
        embed = self.text_encoder(
            x=embed, t=None, padding_mask=tokens_padding_mask
        )  # (B, S, C)
        embed[spk_a_indices] += self.spk_embed(torch.tensor(0, device=device)).to(
            embed.dtype
        )
        embed[spk_b_indices] += self.spk_embed(torch.tensor(1, device=device)).to(
            embed.dtype
        )
        return embed, tokens_lens
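
    # Shape sketch (illustrative, hypothetical sizes): for tokens = [[360, 11],
    # [360, 12, 13]] the padded batch is (2, 3), tokens_lens is tensor([2, 3]),
    # and embed comes back from the text encoder as (2, 3, C).  The in-place
    # additions broadcast a (feat_dim,) speaker-embedding row onto the selected
    # positions, which presumes the encoder output dimension C equals feat_dim.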
    def forward(
        self,
        tokens: List[List[int]],
        features: torch.Tensor,
        features_lens: torch.Tensor,
        noise: torch.Tensor,
        t: torch.Tensor,
        condition_drop_ratio: float = 0.0,
    ) -> torch.Tensor:
        """Forward pass of the model for training.

        Args:
            tokens: a list of lists of token ids.
            features: the acoustic features, with the shape (batch, seq_len, feat_dim).
            features_lens: the length of each acoustic feature sequence, shape (batch,).
            noise: the initial noise, with the shape (batch, seq_len, feat_dim).
            t: the time step, with the shape (batch, 1, 1).
            condition_drop_ratio: the ratio of dropped text condition.
        Returns:
            fm_loss: the flow-matching loss.
        """
        (text_condition, padding_mask,) = self.forward_text_train(
            tokens=tokens,
            features_lens=features_lens,
        )
        speech_condition_mask = condition_time_mask_suffix(
            features_lens=features_lens,
            mask_percent=(0.5, 1.0),
            max_len=features.size(1),
        )
        speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
        if condition_drop_ratio > 0.0:
            drop_mask = (
                torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
                > condition_drop_ratio
            )
            text_condition = text_condition * drop_mask
        xt = features * t + noise * (1 - t)
        ut = features - noise  # (B, T, F)
        vt = self.forward_fm_decoder(
            t=t,
            xt=xt,
            text_condition=text_condition,
            speech_condition=speech_condition,
            padding_mask=padding_mask,
        )
        loss_mask = speech_condition_mask & (~padding_mask)
        fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
        return fm_loss
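
    # Flow-matching sketch: writing x1 = features and x0 = noise, the decoder
    # is fed the linear interpolation xt = t * x1 + (1 - t) * x0 and trained to
    # predict the constant velocity ut = x1 - x0 = d(xt)/dt.  The MSE is taken
    # only over frames that are both inside speech_condition_mask (the span to
    # be generated) and not padding, so conditioning frames carry no loss.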

class ZipVoiceDialogStereo(ZipVoiceDialog):
    """The stereo (two-channel) variant of the ZipVoice-Dialog model."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        required_params = {
            "feat_dim",
            "fm_decoder_downsampling_factor",
            "fm_decoder_num_layers",
            "fm_decoder_cnn_module_kernel",
            "fm_decoder_dim",
            "fm_decoder_feedforward_dim",
            "fm_decoder_num_heads",
            "query_head_dim",
            "pos_head_dim",
            "value_head_dim",
            "pos_dim",
            "time_embed_dim",
        }
        missing = [p for p in required_params if p not in kwargs]
        if missing:
            raise ValueError(f"Missing required parameters: {', '.join(missing)}")
        self.fm_decoder = TTSZipformerTwoStream(
            in_dim=(kwargs["feat_dim"] * 5, kwargs["feat_dim"] * 3),
            out_dim=(kwargs["feat_dim"] * 2, kwargs["feat_dim"]),
            downsampling_factor=kwargs["fm_decoder_downsampling_factor"],
            num_encoder_layers=kwargs["fm_decoder_num_layers"],
            cnn_module_kernel=kwargs["fm_decoder_cnn_module_kernel"],
            encoder_dim=kwargs["fm_decoder_dim"],
            feedforward_dim=kwargs["fm_decoder_feedforward_dim"],
            num_heads=kwargs["fm_decoder_num_heads"],
            query_head_dim=kwargs["query_head_dim"],
            pos_head_dim=kwargs["pos_head_dim"],
            value_head_dim=kwargs["value_head_dim"],
            pos_dim=kwargs["pos_dim"],
            use_time_embed=True,
            time_embed_dim=kwargs["time_embed_dim"],
        )
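
    # Construction note (illustration, hypothetical values): the parameters in
    # required_params must be supplied as keyword arguments, e.g.
    #     model = ZipVoiceDialogStereo(
    #         feat_dim=100,
    #         fm_decoder_downsampling_factor=[1, 2, 4, 2, 1],
    #         fm_decoder_num_layers=[2, 2, 4, 4, 4],
    #         ...,
    #     )
    # because the check inspects **kwargs only; relying on the parent defaults
    # (or passing these values positionally) raises the "Missing required
    # parameters" ValueError even though ZipVoiceDialog defines defaults.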
    def forward(
        self,
        tokens: List[List[int]],
        features: torch.Tensor,
        features_lens: torch.Tensor,
        noise: torch.Tensor,
        t: torch.Tensor,
        condition_drop_ratio: float = 0.0,
        se_weight: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass of the model for training.

        Args:
            tokens: a list of lists of token ids.
            features: the acoustic features, with the shape (batch, seq_len, feat_dim).
            features_lens: the length of each acoustic feature sequence, shape (batch,).
            noise: the initial noise, with the shape (batch, seq_len, feat_dim).
            t: the time step, with the shape (batch, 1, 1).
            condition_drop_ratio: the ratio of dropped text condition.
            se_weight: the weight of the speaker-exclusive loss.
        Returns:
            loss: the flow-matching loss, plus the weighted speaker-exclusive
                loss when se_weight > 0.
        """
        (text_condition, padding_mask,) = self.forward_text_train(
            tokens=tokens,
            features_lens=features_lens,
        )
        speech_condition_mask = condition_time_mask_suffix(
            features_lens=features_lens,
            mask_percent=(0.5, 1.0),
            max_len=features.size(1),
        )
        speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
        if condition_drop_ratio > 0.0:
            drop_mask = (
                torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
                > condition_drop_ratio
            )
            text_condition = text_condition * drop_mask
        xt = features * t + noise * (1 - t)
        ut = features - noise  # (B, T, F)
        vt = self.forward_fm_decoder(
            t=t,
            xt=xt,
            text_condition=text_condition,
            speech_condition=speech_condition,
            padding_mask=padding_mask,
        )
        loss_mask = speech_condition_mask & (~padding_mask)
        fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
        if se_weight > 0:
            target = xt + vt * (1 - t)
            fbank_1 = target[:, :, : self.feat_dim]
            fbank_2 = target[:, :, self.feat_dim :]
            energy_loss = torch.mean(
                self.energy_based_loss(fbank_1, fbank_2, features)[loss_mask]
            )
            loss = fm_loss + energy_loss * se_weight
        else:
            loss = fm_loss
        return loss
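
    # Note on the speaker-exclusive term: xt + vt * (1 - t) is the model's
    # current one-step estimate of the clean features (it equals x1 exactly
    # when vt matches the true velocity x1 - x0).  That estimate is split into
    # the two per-speaker fbank channels and penalised, via energy_based_loss
    # below, on frames where both channels are predicted to be loud at once.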
    def energy_based_loss(self, fbank1, fbank2, gt_fbank):
        # Per-frame energies of the two predicted channels.
        energy1 = self.energy(fbank1)
        energy2 = self.energy(fbank2)
        # Per-utterance energy threshold estimated from both ground-truth
        # channels, concatenated along the time axis.
        energy_thresholds = self.adaptive_threshold_from_gt(
            torch.cat(
                [
                    gt_fbank[:, :, : self.feat_dim],
                    gt_fbank[:, :, self.feat_dim :],
                ],
                dim=1,
            )
        )
        # Penalise frames where both channels exceed the threshold, scaled by
        # how far each channel is above it.
        both_speaking = (
            (energy1 > energy_thresholds) & (energy2 > energy_thresholds)
        ).float()
        penalty = (
            both_speaking
            * (energy1 - energy_thresholds)
            * (energy2 - energy_thresholds)
        )
        return penalty

    def energy(self, fbank):
        # Frame-level energy proxy: mean over the feature bins.
        return torch.mean(fbank, dim=-1)

    def adaptive_threshold_from_gt(self, gt_fbank, percentile=50):
        # Per-utterance threshold: the given percentile (default: the median)
        # of the ground-truth frame energies.
        frame_energies = self.energy(gt_fbank)
        thresholds = torch.quantile(frame_energies, q=percentile / 100, dim=1)
        return thresholds.unsqueeze(1)
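

# Minimal numeric sketch (illustration only): this mirrors the energy-based
# penalty above with plain tensors instead of a model instance, and assumes
# nothing beyond torch itself; all sizes below are made up.
if __name__ == "__main__":
    feat_dim = 4
    # Predicted per-speaker fbank channels and stereo ground truth (random).
    fbank1 = torch.randn(1, 10, feat_dim)
    fbank2 = torch.randn(1, 10, feat_dim)
    gt_fbank = torch.randn(1, 10, 2 * feat_dim)

    def frame_energy(fbank: torch.Tensor) -> torch.Tensor:
        return torch.mean(fbank, dim=-1)

    # Median ground-truth frame energy over both channels -> one threshold per
    # utterance, broadcast over time.
    gt_energy = frame_energy(
        torch.cat([gt_fbank[:, :, :feat_dim], gt_fbank[:, :, feat_dim:]], dim=1)
    )
    threshold = torch.quantile(gt_energy, q=0.5, dim=1).unsqueeze(1)

    energy1, energy2 = frame_energy(fbank1), frame_energy(fbank2)
    both_speaking = ((energy1 > threshold) & (energy2 > threshold)).float()
    penalty = both_speaking * (energy1 - threshold) * (energy2 - threshold)
    print(penalty.shape)  # torch.Size([1, 10]); non-zero only on overlap frames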