# STAR/models/content_encoder/content_encoder.py
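"""Content encoders for STAR.

``ContentEncoder`` routes each sample to the sub-encoder matching its
task string (text, video, MIDI, phoneme, pitch, audio, speech, or
sketch); ``BatchedContentEncoder`` does the same for pre-collated,
single-task batches.
"""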
from typing import Any
import torch
import torch.nn as nn
class ContentEncoder(nn.Module):
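    """Multi-task content encoder.

    Every sub-encoder is optional: pass only the encoders needed by the
    tasks that will actually be requested at runtime.
    """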
    def __init__(
        self,
        embed_dim: int,
        text_encoder: nn.Module | None = None,
        video_encoder: nn.Module | None = None,
        midi_encoder: nn.Module | None = None,
        phoneme_encoder: nn.Module | None = None,
        pitch_encoder: nn.Module | None = None,
        audio_encoder: nn.Module | None = None,
        speech_encoder: nn.Module | None = None,
        sketch_encoder: nn.Module | None = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.text_encoder = text_encoder
        self.midi_encoder = midi_encoder
        self.phoneme_encoder = phoneme_encoder
        self.pitch_encoder = pitch_encoder
        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder
        self.speech_encoder = speech_encoder
        self.sketch_encoder = sketch_encoder
    def encode_content(
        self, batch_content: list[Any], batch_task: list[str],
        device: str | torch.device
    ):
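        """Encode a heterogeneous batch sample by sample.

        Returns padded ``content`` embeddings, a boolean ``content_mask``,
        and ``length_aligned_content`` (a zero placeholder for tasks
        without frame-aligned conditioning).
        """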
        batch_content_output = []
        batch_content_mask = []
        batch_la_content_output = []
        # Placeholder for tasks that carry no length-aligned conditioning.
        zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)
        for content, task in zip(batch_content, batch_task):
            if task == "audio_super_resolution" or task == "speech_enhancement":
                content_dict = {
                    "waveform": torch.as_tensor(content).float(),
                    "waveform_lengths": torch.as_tensor(content.shape[0]),
                }
                for key in list(content_dict.keys()):
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_output_dict = self.audio_encoder(**content_dict)
                la_content_output_dict = {"output": zero_la_content}
            elif task == "text_to_audio" or task == "text_to_music":
                content_output_dict = self.text_encoder([content])
                la_content_output_dict = {"output": zero_la_content}
            elif task == "speech_to_audio":
                input_dict = {
                    "embed": content,
                    "embed_len": torch.tensor(
                        [content.shape[1]], dtype=torch.int
                    ).to(device),
                }
                content_output_dict = self.speech_encoder(input_dict)
                la_content_output_dict = {"output": zero_la_content}
            elif task == "direct_speech_to_audio":
                # HuBERT-style features: content [1, T, dim] with mask [1, T].
                if content.dim() < 3:
                    content = content.unsqueeze(0)
                mask = torch.ones(
                    content.shape[:2], dtype=torch.bool, device=content.device
                )
                content_output_dict = {"output": content, "mask": mask}
                la_content_output_dict = {"output": zero_la_content}
            elif task == "sketch_to_audio":
                content_output_dict = self.sketch_encoder([content["caption"]])
                content_dict = {
                    "f0": torch.as_tensor(content["f0"]),
                    "energy": torch.as_tensor(content["energy"]),
                }
                for key in list(content_dict.keys()):
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                la_content_output_dict = self.sketch_encoder.encode_sketch(
                    **content_dict
                )
            elif task == "video_to_audio":
                content_dict = {
                    "frames": torch.as_tensor(content).float(),
                    "frame_nums": torch.as_tensor(content.shape[0]),
                }
                for key in list(content_dict.keys()):
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_output_dict = self.video_encoder(**content_dict)
                la_content_output_dict = {"output": zero_la_content}
            elif task == "singing_voice_synthesis":
                content_dict = {
                    "phoneme": torch.as_tensor(content["phoneme"]).long(),
                    "midi": torch.as_tensor(content["midi"]).long(),
                    "midi_duration": torch.as_tensor(content["midi_duration"]).float(),
                    "is_slur": torch.as_tensor(content["is_slur"]).long(),
                }
                if "spk" in content:
                    # Speaker identity is either an integer id or a float embedding.
                    if self.midi_encoder.spk_config.encoding_format == "id":
                        content_dict["spk"] = torch.as_tensor(content["spk"]).long()
                    elif self.midi_encoder.spk_config.encoding_format == "embedding":
                        content_dict["spk"] = torch.as_tensor(content["spk"]).float()
                for key in list(content_dict.keys()):
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_dict["lengths"] = torch.as_tensor([len(content["phoneme"])])
                content_output_dict = self.midi_encoder(**content_dict)
                la_content_output_dict = {"output": zero_la_content}
            elif task == "text_to_speech":
                content_dict = {
                    "phoneme": torch.as_tensor(content["phoneme"]).long(),
                }
                if "spk" in content:
                    if self.phoneme_encoder.spk_config.encoding_format == "id":
                        content_dict["spk"] = torch.as_tensor(content["spk"]).long()
                    elif self.phoneme_encoder.spk_config.encoding_format == "embedding":
                        content_dict["spk"] = torch.as_tensor(content["spk"]).float()
                for key in list(content_dict.keys()):
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_dict["lengths"] = torch.as_tensor([len(content["phoneme"])])
                content_output_dict = self.phoneme_encoder(**content_dict)
                la_content_output_dict = {"output": zero_la_content}
            elif task == "singing_acoustic_modeling":
                content_dict = {
                    "phoneme": torch.as_tensor(content["phoneme"]).long(),
                }
                for key in list(content_dict.keys()):
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                content_dict["lengths"] = torch.as_tensor([len(content["phoneme"])])
                content_output_dict = self.pitch_encoder(**content_dict)
                content_dict = {
                    "f0": torch.as_tensor(content["f0"]),
                    "uv": torch.as_tensor(content["uv"]),
                }
                for key in list(content_dict.keys()):
                    content_dict[key] = content_dict[key].unsqueeze(0).to(device)
                la_content_output_dict = self.pitch_encoder.encode_pitch(
                    **content_dict
                )
            else:
                # Fail loudly instead of hitting a NameError below.
                raise ValueError(f"Unsupported task: {task}")
            batch_content_output.append(content_output_dict["output"][0])
            batch_content_mask.append(content_output_dict["mask"][0])
            batch_la_content_output.append(la_content_output_dict["output"][0])
        batch_content_output = nn.utils.rnn.pad_sequence(
            batch_content_output, batch_first=True, padding_value=0
        )
        batch_content_mask = nn.utils.rnn.pad_sequence(
            batch_content_mask, batch_first=True, padding_value=False
        )
        batch_la_content_output = nn.utils.rnn.pad_sequence(
            batch_la_content_output, batch_first=True, padding_value=0
        )
        return {
            "content": batch_content_output,
            "content_mask": batch_content_mask,
            "length_aligned_content": batch_la_content_output,
        }
class BatchedContentEncoder(ContentEncoder):
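    """Content encoder for pre-collated batches.

    Assumes all samples in a batch share the same task, so the whole
    batch is encoded in a single forward pass per sub-encoder.
    """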
    def encode_content(
        self, batch_content: list | dict, batch_task: list[str],
        device: str | torch.device
    ):
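        """Encode one homogeneous (single-task) batch in one pass."""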
        task = batch_task[0]
        zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)
        if task == "audio_super_resolution" or task == "speech_enhancement":
            content_dict = {
                "waveform": batch_content["content"].unsqueeze(1).float().to(device),
                "waveform_lengths": batch_content["content_lengths"].long().to(device),
            }
            content_output = self.audio_encoder(**content_dict)
            la_content_output = zero_la_content
        elif task == "text_to_audio":
            content_output = self.text_encoder(batch_content)
            la_content_output = zero_la_content
        elif task == "video_to_audio":
            content_dict = {
                "frames": batch_content["content"].float().to(device),
                "frame_nums": batch_content["content_lengths"].long().to(device),
            }
            content_output = self.video_encoder(**content_dict)
            la_content_output = zero_la_content
        elif task == "singing_voice_synthesis":
            content_dict = {
                "phoneme": batch_content["phoneme"].long().to(device),
                "midi": batch_content["midi"].long().to(device),
                "midi_duration": batch_content["midi_duration"].float().to(device),
                "is_slur": batch_content["is_slur"].long().to(device),
                "lengths": batch_content["phoneme_lengths"].long().cpu(),
            }
            if "spk" in batch_content:
                if self.midi_encoder.spk_config.encoding_format == "id":
                    content_dict["spk"] = batch_content["spk"].long().to(device)
                elif self.midi_encoder.spk_config.encoding_format == "embedding":
                    content_dict["spk"] = batch_content["spk"].float().to(device)
            content_output = self.midi_encoder(**content_dict)
            la_content_output = zero_la_content
        elif task == "text_to_speech":
            content_dict = {
                "phoneme": batch_content["phoneme"].long().to(device),
                "lengths": batch_content["phoneme_lengths"].long().cpu(),
            }
            if "spk" in batch_content:
                if self.phoneme_encoder.spk_config.encoding_format == "id":
                    content_dict["spk"] = batch_content["spk"].long().to(device)
                elif self.phoneme_encoder.spk_config.encoding_format == "embedding":
                    content_dict["spk"] = batch_content["spk"].float().to(device)
            content_output = self.phoneme_encoder(**content_dict)
            la_content_output = zero_la_content
        elif task == "singing_acoustic_modeling":
            content_dict = {
                "phoneme": batch_content["phoneme"].long().to(device),
                "lengths": batch_content["phoneme_lengths"].long().to(device),
            }
            content_output = self.pitch_encoder(**content_dict)
            content_dict = {
                "f0": batch_content["f0"].float().to(device),
                "uv": batch_content["uv"].float().to(device),
            }
            la_content_output = self.pitch_encoder.encode_pitch(**content_dict)
        else:
            # Fail loudly instead of hitting an UnboundLocalError below.
            raise ValueError(f"Unsupported task: {task}")
        return {
            "content": content_output["output"],
            "content_mask": content_output["mask"],
            "length_aligned_content": la_content_output,
        }
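
# A minimal smoke-test sketch (not part of the original file). It assumes a
# dummy text encoder that mimics the interface implied above: a module that
# takes a list of captions and returns a dict with "output" [B, T, D] and a
# boolean "mask" [B, T].
if __name__ == "__main__":

    class _DummyTextEncoder(nn.Module):
        def __init__(self, embed_dim: int):
            super().__init__()
            self.embed_dim = embed_dim

        def forward(self, captions: list[str]):
            # One row per caption; a fixed length of 4 stands in for real
            # token-level outputs.
            batch, length = len(captions), 4
            return {
                "output": torch.randn(batch, length, self.embed_dim),
                "mask": torch.ones(batch, length, dtype=torch.bool),
            }

    encoder = ContentEncoder(embed_dim=8, text_encoder=_DummyTextEncoder(8))
    out = encoder.encode_content(
        ["a dog barking"], ["text_to_audio"], device="cpu"
    )
    # Expected: content [1, 4, 8], content_mask [1, 4].
    print(out["content"].shape, out["content_mask"].shape)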