# Copyright      2024  Xiaomi Corp.        (authors: Wei Kang
#                                                    Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional

import torch

from zipvoice.models.modules.solver import DistillEulerSolver
from zipvoice.models.modules.zipformer import TTSZipformer
from zipvoice.models.zipvoice import ZipVoice
class ZipVoiceDistill(ZipVoice):
    """ZipVoice-Distill model.

    Extends ``ZipVoice`` by rebuilding the flow-matching decoder with a
    guidance-scale embedding (``use_guidance_scale_embed=True``) and by
    sampling with ``DistillEulerSolver``, so that generation can run in
    very few steps (``num_step`` defaults to 1).
    """

    # Keyword arguments that must be supplied to configure the decoder.
    # A tuple (not a set) so the error message in __init__ lists missing
    # names in a deterministic order.
    _REQUIRED_PARAMS = (
        "feat_dim",
        "fm_decoder_downsampling_factor",
        "fm_decoder_num_layers",
        "fm_decoder_cnn_module_kernel",
        "fm_decoder_dim",
        "fm_decoder_feedforward_dim",
        "fm_decoder_num_heads",
        "query_head_dim",
        "pos_head_dim",
        "value_head_dim",
        "pos_dim",
        "time_embed_dim",
    )

    def __init__(self, *args, **kwargs):
        """Build the distilled model.

        All positional and keyword arguments are forwarded to
        ``ZipVoice.__init__``; in addition, every name listed in
        ``_REQUIRED_PARAMS`` must be present in ``kwargs``.

        Raises:
            ValueError: If any required keyword argument is missing.
        """
        super().__init__(*args, **kwargs)

        missing = [p for p in self._REQUIRED_PARAMS if p not in kwargs]
        if missing:
            raise ValueError(f"Missing required parameters: {', '.join(missing)}")

        # Replace the parent's flow-matching decoder with one that also
        # embeds the classifier-free-guidance scale, as required for the
        # distilled sampling procedure.
        self.fm_decoder = TTSZipformer(
            in_dim=kwargs["feat_dim"] * 3,
            out_dim=kwargs["feat_dim"],
            downsampling_factor=kwargs["fm_decoder_downsampling_factor"],
            num_encoder_layers=kwargs["fm_decoder_num_layers"],
            cnn_module_kernel=kwargs["fm_decoder_cnn_module_kernel"],
            encoder_dim=kwargs["fm_decoder_dim"],
            feedforward_dim=kwargs["fm_decoder_feedforward_dim"],
            num_heads=kwargs["fm_decoder_num_heads"],
            query_head_dim=kwargs["query_head_dim"],
            pos_head_dim=kwargs["pos_head_dim"],
            value_head_dim=kwargs["value_head_dim"],
            pos_dim=kwargs["pos_dim"],
            use_time_embed=True,
            time_embed_dim=kwargs["time_embed_dim"],
            use_guidance_scale_embed=True,
        )

        # Distillation-aware Euler ODE solver that drives the decoder via
        # the "forward_fm_decoder" entry point on this model.
        self.solver = DistillEulerSolver(self, func_name="forward_fm_decoder")

    def forward(
        self,
        tokens: List[List[int]],
        features: torch.Tensor,
        features_lens: torch.Tensor,
        noise: torch.Tensor,
        speech_condition_mask: torch.Tensor,
        t_start: float,
        t_end: float,
        num_step: int = 1,
        guidance_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Sample features from ``t_start`` to ``t_end``.

        Delegates to ``sample_intermediate`` (inherited from ``ZipVoice``),
        forwarding every argument unchanged.

        Args:
            tokens: Token id sequences, one list per utterance.
            features: Target acoustic feature tensor.
            features_lens: Valid lengths of each row of ``features``.
            noise: Initial noise tensor for the ODE solver.
            speech_condition_mask: Mask selecting the speech-conditioned
                positions.
            t_start: Start time of the sampling interval.
            t_end: End time of the sampling interval.
            num_step: Number of solver steps (distilled default: 1).
            guidance_scale: Optional classifier-free-guidance scale tensor.

        Returns:
            The sampled feature tensor produced by ``sample_intermediate``.
        """
        return self.sample_intermediate(
            tokens=tokens,
            features=features,
            features_lens=features_lens,
            noise=noise,
            speech_condition_mask=speech_condition_mask,
            t_start=t_start,
            t_end=t_end,
            num_step=num_step,
            guidance_scale=guidance_scale,
        )