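"""Per-task content encoders for a multi-task audio generation model.

ContentEncoder dispatches each (content, task) pair to the matching
modality encoder and returns padded content embeddings, masks, and
length-aligned content; BatchedContentEncoder does the same for a
pre-collated, single-task batch.
"""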
from typing import Any

import torch
import torch.nn as nn


class ContentEncoder(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        text_encoder: nn.Module | None = None,
        video_encoder: nn.Module | None = None,
        midi_encoder: nn.Module | None = None,
        phoneme_encoder: nn.Module | None = None,
        pitch_encoder: nn.Module | None = None,
        audio_encoder: nn.Module | None = None,
        speech_encoder: nn.Module | None = None,
        sketch_encoder: nn.Module | None = None,
    ):
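        """Store the per-modality encoders.

        Every encoder is optional; only the encoders required by the
        tasks later passed to ``encode_content`` have to be provided.
        """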
        super().__init__()
        self.embed_dim = embed_dim
        self.text_encoder = text_encoder
        self.midi_encoder = midi_encoder
        self.phoneme_encoder = phoneme_encoder
        self.pitch_encoder = pitch_encoder
        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder
        self.speech_encoder = speech_encoder
        self.sketch_encoder = sketch_encoder

    def encode_content(
        self,
        batch_content: list[Any],
        batch_task: list[str],
        device: str | torch.device,
    ):
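        """Encode a heterogeneous batch, one (content, task) pair at a time.

        Returns a dict with padded content embeddings ("content"), a
        boolean padding mask ("content_mask"), and length-aligned
        content ("length_aligned_content"); tasks without a real
        length-aligned representation get a zero placeholder.
        """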
        batch_content_output = []
        batch_content_mask = []
        batch_la_content_output = []
        # Placeholder for tasks that have no length-aligned content.
        zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)
        for content, task in zip(batch_content, batch_task):
            if task in ("audio_super_resolution", "speech_enhancement"):
                content_dict = {
                    "waveform": torch.as_tensor(content).float(),
                    "waveform_lengths": torch.as_tensor(content.shape[0]),
                }
                for key in list(content_dict.keys()):
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_output_dict = self.audio_encoder(**content_dict)
                la_content_output_dict = {"output": zero_la_content}
| elif task == "text_to_audio" or task == "text_to_music": | |
| content_output_dict = self.text_encoder([content]) | |
| la_content_output_dict = { | |
| "output": zero_la_content, | |
| } | |
| elif task == "speech_to_audio": | |
| input_dict = { | |
| "embed": content, | |
| "embed_len": torch.tensor([content.shape[1]], dtype=torch.int).to(device), | |
| } | |
| content_output_dict = self.speech_encoder(input_dict) | |
| la_content_output_dict = { | |
| "output": zero_la_content, | |
| } | |
| elif task == "direct_speech_to_audio": | |
| # content shape [1, L/T 133, dim] mask [1, L/T 133] in hubert | |
| if len(content.shape) < 3: | |
| content = content.unsqueeze(0) | |
| mask = torch.ones(content.shape[:2]) | |
| mask = (mask == 1).to(content.device) | |
| content_output_dict = { | |
| "output": content, | |
| "mask": mask, | |
| } | |
| la_content_output_dict = { | |
| "output": zero_la_content, | |
| } | |
| elif task == "sketch_to_audio": | |
| content_output_dict = self.sketch_encoder([content["caption"]]) | |
| content_dict = { | |
| "f0": torch.as_tensor(content["f0"]), | |
| "energy": torch.as_tensor(content["energy"]), | |
| } | |
| for key in list(content_dict.keys()): | |
| content_dict[key] = content_dict[key].unsqueeze(0).to( | |
| device | |
| ) | |
| la_content_output_dict = self.sketch_encoder.encode_sketch( | |
| **content_dict | |
| ) | |
| elif task == "video_to_audio": | |
| content_dict = { | |
| "frames": torch.as_tensor(content).float(), | |
| "frame_nums": torch.as_tensor(content.shape[0]), | |
| } | |
| for key in list(content_dict.keys()): | |
| content_dict[key] = content_dict[key].unsqueeze(0).to( | |
| device | |
| ) | |
| content_output_dict = self.video_encoder(**content_dict) | |
| la_content_output_dict = { | |
| "output": zero_la_content, | |
| } | |
| elif task == "singing_voice_synthesis": | |
| content_dict = { | |
| "phoneme": | |
| torch.as_tensor(content["phoneme"]).long(), | |
| "midi": | |
| torch.as_tensor(content["midi"]).long(), | |
| "midi_duration": | |
| torch.as_tensor(content["midi_duration"]).float(), | |
| "is_slur": | |
| torch.as_tensor(content["is_slur"]).long() | |
| } | |
| if "spk" in content: | |
| if self.midi_encoder.spk_config.encoding_format == "id": | |
| content_dict["spk"] = torch.as_tensor(content["spk"] | |
| ).long() | |
| elif self.midi_encoder.spk_config.encoding_format == "embedding": | |
| content_dict["spk"] = torch.as_tensor(content["spk"] | |
| ).float() | |
| for key in list(content_dict.keys()): | |
| content_dict[key] = content_dict[key].unsqueeze(0).to( | |
| device | |
| ) | |
| content_dict["lengths"] = torch.as_tensor([ | |
| len(content["phoneme"]) | |
| ]) | |
| content_output_dict = self.midi_encoder(**content_dict) | |
| la_content_output_dict = {"output": zero_la_content} | |
| elif task == "text_to_speech": | |
| content_dict = { | |
| "phoneme": torch.as_tensor(content["phoneme"]).long(), | |
| } | |
| if "spk" in content: | |
| if self.phoneme_encoder.spk_config.encoding_format == "id": | |
| content_dict["spk"] = torch.as_tensor(content["spk"] | |
| ).long() | |
| elif self.phoneme_encoder.spk_config.encoding_format == "embedding": | |
| content_dict["spk"] = torch.as_tensor(content["spk"] | |
| ).float() | |
| for key in list(content_dict.keys()): | |
| content_dict[key] = content_dict[key].unsqueeze(0).to( | |
| device | |
| ) | |
| content_dict["lengths"] = torch.as_tensor([ | |
| len(content["phoneme"]) | |
| ]) | |
| content_output_dict = self.phoneme_encoder(**content_dict) | |
| la_content_output_dict = {"output": zero_la_content} | |
| elif task == "singing_acoustic_modeling": | |
| content_dict = { | |
| "phoneme": torch.as_tensor(content["phoneme"]).long(), | |
| } | |
| for key in list(content_dict.keys()): | |
| content_dict[key] = content_dict[key].unsqueeze(0).to( | |
| device | |
| ) | |
| content_dict["lengths"] = torch.as_tensor([ | |
| len(content["phoneme"]) | |
| ]) | |
| content_output_dict = self.pitch_encoder(**content_dict) | |
| content_dict = { | |
| "f0": torch.as_tensor(content["f0"]), | |
| "uv": torch.as_tensor(content["uv"]), | |
| } | |
| for key in list(content_dict.keys()): | |
| content_dict[key] = content_dict[key].unsqueeze(0).to( | |
| device | |
| ) | |
| la_content_output_dict = self.pitch_encoder.encode_pitch( | |
| **content_dict | |
| ) | |
            else:
                # Fail fast instead of silently reusing outputs from the
                # previous loop iteration.
                raise ValueError(f"Unsupported task: {task}")
            batch_content_output.append(content_output_dict["output"][0])
            batch_content_mask.append(content_output_dict["mask"][0])
            batch_la_content_output.append(la_content_output_dict["output"][0])
        batch_content_output = nn.utils.rnn.pad_sequence(
            batch_content_output, batch_first=True, padding_value=0
        )
        batch_content_mask = nn.utils.rnn.pad_sequence(
            batch_content_mask, batch_first=True, padding_value=False
        )
        batch_la_content_output = nn.utils.rnn.pad_sequence(
            batch_la_content_output, batch_first=True, padding_value=0
        )
        return {
            "content": batch_content_output,
            "content_mask": batch_content_mask,
            "length_aligned_content": batch_la_content_output,
        }


class BatchedContentEncoder(ContentEncoder):
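    """ContentEncoder variant for pre-collated, single-task batches.

    Instead of iterating over per-item content, it assumes every item
    in the batch shares the same task and encodes the whole batch in
    one encoder call.
    """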
    def encode_content(
        self,
        batch_content: list | dict,
        batch_task: list[str],
        device: str | torch.device,
    ):
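        """Encode a collated batch whose task is given by batch_task[0].

        Returns the same dict layout as ContentEncoder.encode_content,
        without per-item padding, since the batch is already collated.
        """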
        task = batch_task[0]
        zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)
| if task == "audio_super_resolution" or task == "speech_enhancement": | |
| content_dict = { | |
| "waveform": | |
| batch_content["content"].unsqueeze(1).float().to(device), | |
| "waveform_lengths": | |
| batch_content["content_lengths"].long().to(device), | |
| } | |
| content_output = self.audio_encoder(**content_dict) | |
| la_content_output = zero_la_content | |
| elif task == "text_to_audio": | |
| content_output = self.text_encoder(batch_content) | |
| la_content_output = zero_la_content | |
| elif task == "video_to_audio": | |
| content_dict = { | |
| "frames": | |
| batch_content["content"].float().to(device), | |
| "frame_nums": | |
| batch_content["content_lengths"].long().to(device), | |
| } | |
| content_output = self.video_encoder(**content_dict) | |
| la_content_output = zero_la_content | |
| elif task == "singing_voice_synthesis": | |
| content_dict = { | |
| "phoneme": | |
| batch_content["phoneme"].long().to(device), | |
| "midi": | |
| batch_content["midi"].long().to(device), | |
| "midi_duration": | |
| batch_content["midi_duration"].float().to(device), | |
| "is_slur": | |
| batch_content["is_slur"].long().to(device), | |
| "lengths": | |
| batch_content["phoneme_lengths"].long().cpu(), | |
| } | |
| if "spk" in batch_content: | |
| if self.midi_encoder.spk_config.encoding_format == "id": | |
| content_dict["spk"] = batch_content["spk"].long( | |
| ).to(device) | |
| elif self.midi_encoder.spk_config.encoding_format == "embedding": | |
| content_dict["spk"] = batch_content["spk"].float( | |
| ).to(device) | |
| content_output = self.midi_encoder(**content_dict) | |
| la_content_output = zero_la_content | |
| elif task == "text_to_speech": | |
| content_dict = { | |
| "phoneme": batch_content["phoneme"].long().to(device), | |
| "lengths": batch_content["phoneme_lengths"].long().cpu(), | |
| } | |
| if "spk" in batch_content: | |
| if self.phoneme_encoder.spk_config.encoding_format == "id": | |
| content_dict["spk"] = batch_content["spk"].long( | |
| ).to(device) | |
| elif self.phoneme_encoder.spk_config.encoding_format == "embedding": | |
| content_dict["spk"] = batch_content["spk"].float( | |
| ).to(device) | |
| content_output = self.phoneme_encoder(**content_dict) | |
| la_content_output = zero_la_content | |
| elif task == "singing_acoustic_modeling": | |
| content_dict = { | |
| "phoneme": batch_content["phoneme"].long().to(device), | |
| "lengths": batch_content["phoneme_lengths"].long().to(device), | |
| } | |
| content_output = self.pitch_encoder(**content_dict) | |
| content_dict = { | |
| "f0": batch_content["f0"].float().to(device), | |
| "uv": batch_content["uv"].float().to(device), | |
| } | |
| la_content_output = self.pitch_encoder.encode_pitch(**content_dict) | |
        return {
            "content": content_output["output"],
            "content_mask": content_output["mask"],
            "length_aligned_content": la_content_output,
        }
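

if __name__ == "__main__":
    # Minimal smoke-test sketch for the text_to_audio path.
    # DummyTextEncoder is hypothetical: it stands in for a real text
    # encoder and only mimics the interface this module relies on
    # (a dict with "output" [B, T, D] and a boolean "mask" [B, T]).
    class DummyTextEncoder(nn.Module):

        def __init__(self, embed_dim: int, num_tokens: int = 8):
            super().__init__()
            self.embed_dim = embed_dim
            self.num_tokens = num_tokens

        def forward(self, texts: list[str]):
            batch = len(texts)
            return {
                "output": torch.randn(batch, self.num_tokens, self.embed_dim),
                "mask": torch.ones(batch, self.num_tokens, dtype=torch.bool),
            }

    embed_dim = 16
    encoder = ContentEncoder(embed_dim, text_encoder=DummyTextEncoder(embed_dim))
    out = encoder.encode_content(
        ["a dog barking"], ["text_to_audio"], device="cpu"
    )
    print(out["content"].shape)                 # (1, 8, 16)
    print(out["content_mask"].shape)            # (1, 8)
    print(out["length_aligned_content"].shape)  # (1, 1, 16)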