Spaces: wsntxxn / Running on Zero

Yixuan Li committed on
Commit 4853fdc · 1 Parent(s): 5514739

first commit
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. app.py +98 -6
  2. ckpts/1m.pt +3 -0
  3. ckpts/exp0_best.pt +3 -0
  4. configs/config.yaml +84 -0
  5. configs/infer.yaml +16 -0
  6. models/__pycache__/common.cpython-310.pyc +0 -0
  7. models/__pycache__/content_adapter.cpython-310.pyc +0 -0
  8. models/__pycache__/diffusion.cpython-310.pyc +0 -0
  9. models/__pycache__/flow_matching.cpython-310.pyc +0 -0
  10. models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc +0 -0
  11. models/autoencoder/autoencoder_base.py +22 -0
  12. models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc +0 -0
  13. models/autoencoder/waveform/dac.py +0 -0
  14. models/autoencoder/waveform/stable_vae.py +559 -0
  15. models/common.py +67 -0
  16. models/content_adapter.py +381 -0
  17. models/content_encoder/__pycache__/content_encoder.cpython-310.pyc +0 -0
  18. models/content_encoder/__pycache__/sketch_encoder.cpython-310.pyc +0 -0
  19. models/content_encoder/__pycache__/text_encoder.cpython-310.pyc +0 -0
  20. models/content_encoder/content_encoder.py +280 -0
  21. models/content_encoder/midi_encoder.py +1046 -0
  22. models/content_encoder/sketch_encoder.py +51 -0
  23. models/content_encoder/star_encoder/__pycache__/Qformer.cpython-310.pyc +0 -0
  24. models/content_encoder/star_encoder/__pycache__/star_encoder.cpython-310.pyc +0 -0
  25. models/content_encoder/star_encoder/star_encoder.py +108 -0
  26. models/content_encoder/text_encoder.py +77 -0
  27. models/content_encoder/vision_encoder.py +34 -0
  28. models/diffsinger_net.py +119 -0
  29. models/diffusion.py +1261 -0
  30. models/dit/__pycache__/attention.cpython-310.pyc +0 -0
  31. models/dit/__pycache__/audio_dit.cpython-310.pyc +0 -0
  32. models/dit/__pycache__/mask_dit.cpython-310.pyc +0 -0
  33. models/dit/__pycache__/modules.cpython-310.pyc +0 -0
  34. models/dit/__pycache__/rotary.cpython-310.pyc +0 -0
  35. models/dit/__pycache__/span_mask.cpython-310.pyc +0 -0
  36. models/dit/attention.py +349 -0
  37. models/dit/audio_diffsingernet_dit.py +520 -0
  38. models/dit/audio_dit.py +652 -0
  39. models/dit/mask_dit.py +823 -0
  40. models/dit/modules.py +445 -0
  41. models/dit/rotary.py +88 -0
  42. models/dit/span_mask.py +149 -0
  43. models/flow_matching.py +1267 -0
  44. requirements.txt +149 -3
  45. utils/__pycache__/accelerate_utilities.cpython-310.pyc +0 -0
  46. utils/__pycache__/config.cpython-310.pyc +0 -0
  47. utils/__pycache__/diffsinger_utilities.cpython-310.pyc +0 -0
  48. utils/__pycache__/general.cpython-310.pyc +0 -0
  49. utils/__pycache__/logging.cpython-310.pyc +0 -0
  50. utils/__pycache__/lr_scheduler_utilities.cpython-310.pyc +0 -0
app.py CHANGED
@@ -1,26 +1,118 @@
 import gradio as gr
+from pathlib import Path
+
+import soundfile as sf
+import torch
+import torchaudio
+import hydra
+from omegaconf import OmegaConf
+import diffusers.schedulers as noise_schedulers
+
+from utils.config import register_omegaconf_resolvers
+from models.common import LoadPretrainedBase
+
+from huggingface_hub import hf_hub_download
+import fairseq
+
+register_omegaconf_resolvers()
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+config = OmegaConf.load("configs/infer.yaml")
+
+ckpt_path = hf_hub_download(
+    repo_id="assasinatee/STAR",
+    filename="model.safetensors",
+    repo_type="model",
+    force_download=False
+)
+
+exp_config = OmegaConf.load("configs/config.yaml")
+if "pretrained_ckpt" in exp_config["model"]:
+    exp_config["model"]["pretrained_ckpt"] = ckpt_path
+model: LoadPretrainedBase = hydra.utils.instantiate(exp_config["model"])
+
+model = model.to(device)
+
+ckpt_path = hf_hub_download(
+    repo_id="assasinatee/STAR",
+    filename="hubert_large_ll60k.pt",
+    repo_type="model",
+    force_download=False
+)
+hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
+hubert_model = hubert_models[0].eval().to(device)
+
+scheduler = getattr(
+    noise_schedulers,
+    config["noise_scheduler"]["type"],
+).from_pretrained(
+    config["noise_scheduler"]["name"],
+    subfolder="scheduler",
+)
+
+@torch.no_grad()
+def infer(audio_path: str) -> str:
+    waveform_tts, sample_rate = torchaudio.load(audio_path)
+    if sample_rate != 16000:
+        waveform_tts = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_tts)
+    if waveform_tts.shape[0] > 1:
+        waveform_tts = torch.mean(waveform_tts, dim=0, keepdim=True)
+    with torch.no_grad():
+        features, _ = hubert_model.extract_features(waveform_tts.to(device))
+
+    kwargs = OmegaConf.to_container(config["infer_args"].copy(), resolve=True)
+    kwargs['content'] = [features]
+    kwargs['condition'] = None
+    kwargs['task'] = ["speech_to_audio"]
+
+    model.eval()
+    waveform = model.inference(
+        scheduler=scheduler,
+        **kwargs,
+    )
+
+    output_file = "output_audio.wav"
+    sf.write(output_file, waveform.squeeze().cpu().numpy(), samplerate=exp_config["sample_rate"])
+
+    return output_file
 
 with gr.Blocks(title="STAR Online Inference", theme=gr.themes.Soft()) as demo:
     gr.Markdown("# STAR: Speech-to-Audio Generation via Representation Learning")
 
     gr.Markdown("""
     <div style="text-align: left; padding: 10px;">
+
     ## 🗣️ Input
 
     A brief input speech utterance for the overall audio scene.
 
-    > ExampleA cat meowing and young female speaking
-
-    #### 🎙️ Input Speech Example
-
-    #### 🎧️ Output Audio Example
+    > Example:A cat meowing and young female speaking
+
+    ### 🎙️ Input Speech Example
+    """)
+
+    speech = gr.Audio(value="wav/speech.wav", label="Input Speech Example", type="filepath")
+
+    gr.Markdown("""
+    <div style="text-align: left; padding: 10px;">
+
+    ### 🎧️ Output Audio Example
+    """)
+
+    audio = gr.Audio(value="wav/audio.wav", label="Generated Audio Example", type="filepath")
+
+    gr.Markdown("""
     </div>
     ---
     </div>
-    """)
-
-if __name__ == "__main__":
-    demo.launch()
+    """)
+
+    with gr.Column():
+        input_audio = gr.Audio(label="Speech Input", type="filepath")
+        btn = gr.Button("🎵Generate Audio!", variant="primary")
+        output_audio = gr.Audio(label="Generated Audio", type="filepath")
+        btn.click(fn=infer, inputs=input_audio, outputs=output_audio)
+
+demo.launch()
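Note on the new app.py: all model loading happens at import time, and demo.launch() runs unconditionally (the old `if __name__ == "__main__":` guard is gone), so `python app.py` both downloads model.safetensors and hubert_large_ll60k.pt from the assasinatee/STAR model repo and starts the Gradio server. infer() resamples the uploaded speech to 16 kHz mono, extracts HuBERT features, runs the model with the "speech_to_audio" task and the settings from configs/infer.yaml, and writes the result to output_audio.wav.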
ckpts/1m.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3cb13e2699fa922ce6a2b3b4f53c270ec64156e0cc3f3e3645e10cdf98b740dc
+size 183037614
ckpts/exp0_best.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e2dc436e6d47cb02e954a0087a3a1b4aa1d5d3e1ded4fdafb6274966264d5a7
+size 73171895
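Both checkpoints are committed as Git LFS pointer files: the repository only tracks the sha256 oid and byte size (about 183 MB for ckpts/1m.pt and 73 MB for ckpts/exp0_best.pt), and the actual .pt weights are resolved from LFS storage when the Space is cloned or built.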
configs/config.yaml ADDED
@@ -0,0 +1,84 @@
+model:
+  autoencoder:
+    _target_: models.autoencoder.waveform.stable_vae.StableVAE
+    encoder:
+      _target_: models.autoencoder.waveform.stable_vae.OobleckEncoder
+      in_channels: 1
+      channels: 128
+      c_mults:
+      - 1
+      - 2
+      - 4
+      - 8
+      strides:
+      - 2
+      - 4
+      - 6
+      - 10
+      latent_dim: 256
+      use_snake: true
+    decoder:
+      _target_: models.autoencoder.waveform.stable_vae.OobleckDecoder
+      out_channels: 1
+      channels: 128
+      c_mults:
+      - 1
+      - 2
+      - 4
+      - 8
+      strides:
+      - 2
+      - 4
+      - 6
+      - 10
+      latent_dim: 128
+      use_snake: true
+      final_tanh: false
+    io_channels: 1
+    latent_dim: 128
+    downsampling_ratio: 480
+    sample_rate: 24000
+    pretrained_ckpt: /hpc_stor03/sjtu_home/xuenan.xu/workspace/text_to_audio_generation/ezaudio/ckpts/vae/1m.pt
+    bottleneck:
+      _target_: models.autoencoder.waveform.stable_vae.VAEBottleneck
+  backbone:
+    _target_: models.dit.mask_dit.UDiT
+    img_size: 500
+    patch_size: 1
+    in_chans: 128
+    out_chans: 128
+    input_type: 1d
+    embed_dim: 1024
+    depth: 24
+    num_heads: 16
+    mlp_ratio: 4.0
+    qkv_bias: false
+    qk_scale: null
+    qk_norm: layernorm
+    norm_layer: layernorm
+    act_layer: geglu
+    context_norm: true
+    use_checkpoint: true
+    time_fusion: ada_sola_bias
+    ada_sola_rank: 32
+    ada_sola_alpha: 32
+    cls_dim: null
+    context_dim: 1024
+    context_fusion: cross
+    context_max_length: null
+    context_pe_method: none
+    pe_method: none
+    rope_mode: shared
+    use_conv: true
+    skip: true
+    skip_norm: true
+    cfg_drop_ratio: 0.2
+  _target_: models.flow_matching.SingleTaskCrossAttentionAudioFlowMatching
+  content_encoder:
+    _target_: models.content_encoder.content_encoder.ContentEncoder
+    embed_dim: 1024
+    text_encoder: None
+    speech_encoder:
+      _target_: models.content_encoder.star_encoder.star_encoder.QformerBridgeNet
+      load_from_pretrained: /hpc_stor03/sjtu_home/zeyu.xie/workspace/speech2audio/hear/output/qformer_caption_tts_hubert/exp0_best.pt
+  pretrained_ckpt: /hpc_stor03/sjtu_home/zeyu.xie/workspace/speech2audio/x2audio/x_to_audio_generation/experiments/audiocaps_fm/checkpoints/epoch_100/model.safetensors
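The absolute /hpc_stor03/... checkpoint paths above are cluster-specific; app.py works around this by overriding model.pretrained_ckpt with a checkpoint downloaded from the assasinatee/STAR model repo before instantiating the config with Hydra. A minimal sketch of that pattern (the safetensors path below is a placeholder):

import hydra
from omegaconf import OmegaConf

exp_config = OmegaConf.load("configs/config.yaml")
# Replace the cluster-specific path with a locally available checkpoint (placeholder path).
exp_config["model"]["pretrained_ckpt"] = "/path/to/model.safetensors"
model = hydra.utils.instantiate(exp_config["model"])  # builds SingleTaskCrossAttentionAudioFlowMatching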
configs/infer.yaml ADDED
@@ -0,0 +1,16 @@
+defaults:
+  - basic
+  - _self_
+
+wav_dir: inference_delay
+
+noise_scheduler:
+  type: DDIMScheduler
+  name: stabilityai/stable-diffusion-2-1
+
+infer_args:
+  num_steps: 50
+  guidance_scale: 3.5
+  guidance_rescale: 0.5
+  use_gt_duration: false
+  latent_shape: [128, 500]
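For reference, this is how app.py consumes infer.yaml: noise_scheduler.type is looked up in diffusers.schedulers, noise_scheduler.name is the hub repo the scheduler config is pulled from, and infer_args becomes the keyword arguments of model.inference(). A minimal sketch mirroring app.py:

import diffusers.schedulers as noise_schedulers
from omegaconf import OmegaConf

config = OmegaConf.load("configs/infer.yaml")
scheduler = getattr(noise_schedulers, config["noise_scheduler"]["type"]).from_pretrained(
    config["noise_scheduler"]["name"], subfolder="scheduler"
)
infer_kwargs = OmegaConf.to_container(config["infer_args"], resolve=True)  # num_steps, guidance_scale, ...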
models/__pycache__/common.cpython-310.pyc ADDED
Binary file (2.99 kB). View file
 
models/__pycache__/content_adapter.cpython-310.pyc ADDED
Binary file (10.8 kB). View file
 
models/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (24.6 kB). View file
 
models/__pycache__/flow_matching.cpython-310.pyc ADDED
Binary file (25.8 kB). View file
 
models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc ADDED
Binary file (1.05 kB). View file
 
models/autoencoder/autoencoder_base.py ADDED
@@ -0,0 +1,22 @@
+from abc import abstractmethod, ABC
+from typing import Sequence
+import torch
+import torch.nn as nn
+
+
+class AutoEncoderBase(ABC):
+    def __init__(
+        self, downsampling_ratio: int, sample_rate: int,
+        latent_shape: Sequence[int | None]
+    ):
+        self.downsampling_ratio = downsampling_ratio
+        self.sample_rate = sample_rate
+        self.latent_token_rate = sample_rate // downsampling_ratio
+        self.latent_shape = latent_shape
+        self.time_dim = latent_shape.index(None) + 1  # the first dim is batch
+
+    @abstractmethod
+    def encode(
+        self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        ...
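A toy illustration of what this abstract base enforces (StableVAE below is the real implementation); the subclass and its shapes are purely illustrative, not part of the repo:

import torch
from models.autoencoder.autoencoder_base import AutoEncoderBase

class IdentityAutoEncoder(AutoEncoderBase):
    """Hypothetical subclass that satisfies the interface without real compression."""
    def encode(self, waveform: torch.Tensor, waveform_lengths: torch.Tensor):
        latents = waveform.unsqueeze(1)  # pretend (B, 1, T) is the latent
        mask = torch.ones(latents.shape[0], latents.shape[-1], dtype=torch.bool)
        return latents, mask

ae = IdentityAutoEncoder(downsampling_ratio=480, sample_rate=24000, latent_shape=(128, None))
print(ae.latent_token_rate)  # 24000 // 480 = 50 latent frames per second
print(ae.time_dim)           # 2: index of the None (time) axis, counting the batch dim first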
models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
models/autoencoder/waveform/dac.py ADDED
File without changes
models/autoencoder/waveform/stable_vae.py ADDED
@@ -0,0 +1,559 @@
1
+ from typing import Any, Literal, Callable
2
+ import math
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.utils import weight_norm
8
+ import torchaudio
9
+ from alias_free_torch import Activation1d
10
+
11
+ from models.common import LoadPretrainedBase
12
+ from models.autoencoder.autoencoder_base import AutoEncoderBase
13
+ from utils.torch_utilities import remove_key_prefix_factory, create_mask_from_length
14
+
15
+
16
+ # jit script make it 1.4x faster and save GPU memory
17
+ @torch.jit.script
18
+ def snake_beta(x, alpha, beta):
19
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
20
+
21
+
22
+ class SnakeBeta(nn.Module):
23
+ def __init__(
24
+ self,
25
+ in_features,
26
+ alpha=1.0,
27
+ alpha_trainable=True,
28
+ alpha_logscale=True
29
+ ):
30
+ super(SnakeBeta, self).__init__()
31
+ self.in_features = in_features
32
+
33
+ # initialize alpha
34
+ self.alpha_logscale = alpha_logscale
35
+ if self.alpha_logscale:
36
+ # log scale alphas initialized to zeros
37
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
38
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
39
+ else:
40
+ # linear scale alphas initialized to ones
41
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
42
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
43
+
44
+ self.alpha.requires_grad = alpha_trainable
45
+ self.beta.requires_grad = alpha_trainable
46
+
47
+ # self.no_div_by_zero = 0.000000001
48
+
49
+ def forward(self, x):
50
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
51
+ # line up with x to [B, C, T]
52
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
53
+ if self.alpha_logscale:
54
+ alpha = torch.exp(alpha)
55
+ beta = torch.exp(beta)
56
+ x = snake_beta(x, alpha, beta)
57
+
58
+ return x
59
+
60
+
61
+ def WNConv1d(*args, **kwargs):
62
+ return weight_norm(nn.Conv1d(*args, **kwargs))
63
+
64
+
65
+ def WNConvTranspose1d(*args, **kwargs):
66
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
67
+
68
+
69
+ def get_activation(
70
+ activation: Literal["elu", "snake", "none"],
71
+ antialias=False,
72
+ channels=None
73
+ ) -> nn.Module:
74
+ if activation == "elu":
75
+ act = nn.ELU()
76
+ elif activation == "snake":
77
+ act = SnakeBeta(channels)
78
+ elif activation == "none":
79
+ act = nn.Identity()
80
+ else:
81
+ raise ValueError(f"Unknown activation {activation}")
82
+
83
+ if antialias:
84
+ act = Activation1d(act)
85
+
86
+ return act
87
+
88
+
89
+ class ResidualUnit(nn.Module):
90
+ def __init__(
91
+ self,
92
+ in_channels,
93
+ out_channels,
94
+ dilation,
95
+ use_snake=False,
96
+ antialias_activation=False
97
+ ):
98
+ super().__init__()
99
+
100
+ self.dilation = dilation
101
+
102
+ padding = (dilation * (7 - 1)) // 2
103
+
104
+ self.layers = nn.Sequential(
105
+ get_activation(
106
+ "snake" if use_snake else "elu",
107
+ antialias=antialias_activation,
108
+ channels=out_channels
109
+ ),
110
+ WNConv1d(
111
+ in_channels=in_channels,
112
+ out_channels=out_channels,
113
+ kernel_size=7,
114
+ dilation=dilation,
115
+ padding=padding
116
+ ),
117
+ get_activation(
118
+ "snake" if use_snake else "elu",
119
+ antialias=antialias_activation,
120
+ channels=out_channels
121
+ ),
122
+ WNConv1d(
123
+ in_channels=out_channels,
124
+ out_channels=out_channels,
125
+ kernel_size=1
126
+ )
127
+ )
128
+
129
+ def forward(self, x):
130
+ res = x
131
+
132
+ #x = checkpoint(self.layers, x)
133
+ x = self.layers(x)
134
+
135
+ return x + res
136
+
137
+
138
+ class EncoderBlock(nn.Module):
139
+ def __init__(
140
+ self,
141
+ in_channels,
142
+ out_channels,
143
+ stride,
144
+ use_snake=False,
145
+ antialias_activation=False
146
+ ):
147
+ super().__init__()
148
+
149
+ self.layers = nn.Sequential(
150
+ ResidualUnit(
151
+ in_channels=in_channels,
152
+ out_channels=in_channels,
153
+ dilation=1,
154
+ use_snake=use_snake
155
+ ),
156
+ ResidualUnit(
157
+ in_channels=in_channels,
158
+ out_channels=in_channels,
159
+ dilation=3,
160
+ use_snake=use_snake
161
+ ),
162
+ ResidualUnit(
163
+ in_channels=in_channels,
164
+ out_channels=in_channels,
165
+ dilation=9,
166
+ use_snake=use_snake
167
+ ),
168
+ get_activation(
169
+ "snake" if use_snake else "elu",
170
+ antialias=antialias_activation,
171
+ channels=in_channels
172
+ ),
173
+ WNConv1d(
174
+ in_channels=in_channels,
175
+ out_channels=out_channels,
176
+ kernel_size=2 * stride,
177
+ stride=stride,
178
+ padding=math.ceil(stride / 2)
179
+ ),
180
+ )
181
+
182
+ def forward(self, x):
183
+ return self.layers(x)
184
+
185
+
186
+ class DecoderBlock(nn.Module):
187
+ def __init__(
188
+ self,
189
+ in_channels,
190
+ out_channels,
191
+ stride,
192
+ use_snake=False,
193
+ antialias_activation=False,
194
+ use_nearest_upsample=False
195
+ ):
196
+ super().__init__()
197
+
198
+ if use_nearest_upsample:
199
+ upsample_layer = nn.Sequential(
200
+ nn.Upsample(scale_factor=stride, mode="nearest"),
201
+ WNConv1d(
202
+ in_channels=in_channels,
203
+ out_channels=out_channels,
204
+ kernel_size=2 * stride,
205
+ stride=1,
206
+ bias=False,
207
+ padding='same'
208
+ )
209
+ )
210
+ else:
211
+ upsample_layer = WNConvTranspose1d(
212
+ in_channels=in_channels,
213
+ out_channels=out_channels,
214
+ kernel_size=2 * stride,
215
+ stride=stride,
216
+ padding=math.ceil(stride / 2)
217
+ )
218
+
219
+ self.layers = nn.Sequential(
220
+ get_activation(
221
+ "snake" if use_snake else "elu",
222
+ antialias=antialias_activation,
223
+ channels=in_channels
224
+ ),
225
+ upsample_layer,
226
+ ResidualUnit(
227
+ in_channels=out_channels,
228
+ out_channels=out_channels,
229
+ dilation=1,
230
+ use_snake=use_snake
231
+ ),
232
+ ResidualUnit(
233
+ in_channels=out_channels,
234
+ out_channels=out_channels,
235
+ dilation=3,
236
+ use_snake=use_snake
237
+ ),
238
+ ResidualUnit(
239
+ in_channels=out_channels,
240
+ out_channels=out_channels,
241
+ dilation=9,
242
+ use_snake=use_snake
243
+ ),
244
+ )
245
+
246
+ def forward(self, x):
247
+ return self.layers(x)
248
+
249
+
250
+ class OobleckEncoder(nn.Module):
251
+ def __init__(
252
+ self,
253
+ in_channels=2,
254
+ channels=128,
255
+ latent_dim=32,
256
+ c_mults=[1, 2, 4, 8],
257
+ strides=[2, 4, 8, 8],
258
+ use_snake=False,
259
+ antialias_activation=False
260
+ ):
261
+ super().__init__()
262
+
263
+ c_mults = [1] + c_mults
264
+
265
+ self.depth = len(c_mults)
266
+
267
+ layers = [
268
+ WNConv1d(
269
+ in_channels=in_channels,
270
+ out_channels=c_mults[0] * channels,
271
+ kernel_size=7,
272
+ padding=3
273
+ )
274
+ ]
275
+
276
+ for i in range(self.depth - 1):
277
+ layers += [
278
+ EncoderBlock(
279
+ in_channels=c_mults[i] * channels,
280
+ out_channels=c_mults[i + 1] * channels,
281
+ stride=strides[i],
282
+ use_snake=use_snake
283
+ )
284
+ ]
285
+
286
+ layers += [
287
+ get_activation(
288
+ "snake" if use_snake else "elu",
289
+ antialias=antialias_activation,
290
+ channels=c_mults[-1] * channels
291
+ ),
292
+ WNConv1d(
293
+ in_channels=c_mults[-1] * channels,
294
+ out_channels=latent_dim,
295
+ kernel_size=3,
296
+ padding=1
297
+ )
298
+ ]
299
+
300
+ self.layers = nn.Sequential(*layers)
301
+
302
+ def forward(self, x):
303
+ return self.layers(x)
304
+
305
+
306
+ class OobleckDecoder(nn.Module):
307
+ def __init__(
308
+ self,
309
+ out_channels=2,
310
+ channels=128,
311
+ latent_dim=32,
312
+ c_mults=[1, 2, 4, 8],
313
+ strides=[2, 4, 8, 8],
314
+ use_snake=False,
315
+ antialias_activation=False,
316
+ use_nearest_upsample=False,
317
+ final_tanh=True
318
+ ):
319
+ super().__init__()
320
+
321
+ c_mults = [1] + c_mults
322
+
323
+ self.depth = len(c_mults)
324
+
325
+ layers = [
326
+ WNConv1d(
327
+ in_channels=latent_dim,
328
+ out_channels=c_mults[-1] * channels,
329
+ kernel_size=7,
330
+ padding=3
331
+ ),
332
+ ]
333
+
334
+ for i in range(self.depth - 1, 0, -1):
335
+ layers += [
336
+ DecoderBlock(
337
+ in_channels=c_mults[i] * channels,
338
+ out_channels=c_mults[i - 1] * channels,
339
+ stride=strides[i - 1],
340
+ use_snake=use_snake,
341
+ antialias_activation=antialias_activation,
342
+ use_nearest_upsample=use_nearest_upsample
343
+ )
344
+ ]
345
+
346
+ layers += [
347
+ get_activation(
348
+ "snake" if use_snake else "elu",
349
+ antialias=antialias_activation,
350
+ channels=c_mults[0] * channels
351
+ ),
352
+ WNConv1d(
353
+ in_channels=c_mults[0] * channels,
354
+ out_channels=out_channels,
355
+ kernel_size=7,
356
+ padding=3,
357
+ bias=False
358
+ ),
359
+ nn.Tanh() if final_tanh else nn.Identity()
360
+ ]
361
+
362
+ self.layers = nn.Sequential(*layers)
363
+
364
+ def forward(self, x):
365
+ return self.layers(x)
366
+
367
+
368
+ class Bottleneck(nn.Module):
369
+ def __init__(self, is_discrete: bool = False):
370
+ super().__init__()
371
+
372
+ self.is_discrete = is_discrete
373
+
374
+ def encode(self, x, return_info=False, **kwargs):
375
+ raise NotImplementedError
376
+
377
+ def decode(self, x):
378
+ raise NotImplementedError
379
+
380
+
381
+ @torch.jit.script
382
+ def vae_sample(mean, scale) -> dict[str, torch.Tensor]:
383
+ stdev = nn.functional.softplus(scale) + 1e-4
384
+ var = stdev * stdev
385
+ logvar = torch.log(var)
386
+ latents = torch.randn_like(mean) * stdev + mean
387
+
388
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
389
+ return {"latents": latents, "kl": kl}
390
+
391
+
392
+ class VAEBottleneck(Bottleneck):
393
+ def __init__(self):
394
+ super().__init__(is_discrete=False)
395
+
396
+ def encode(self,
397
+ x,
398
+ return_info=False,
399
+ **kwargs) -> dict[str, torch.Tensor] | torch.Tensor:
400
+ mean, scale = x.chunk(2, dim=1)
401
+ sampled = vae_sample(mean, scale)
402
+
403
+ if return_info:
404
+ return sampled["latents"], {"kl": sampled["kl"]}
405
+ else:
406
+ return sampled["latents"]
407
+
408
+ def decode(self, x):
409
+ return x
410
+
411
+
412
+ def compute_mean_kernel(x, y):
413
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
414
+ return torch.exp(-kernel_input).mean()
415
+
416
+
417
+ class Pretransform(nn.Module):
418
+ def __init__(self, enable_grad, io_channels, is_discrete):
419
+ super().__init__()
420
+
421
+ self.is_discrete = is_discrete
422
+ self.io_channels = io_channels
423
+ self.encoded_channels = None
424
+ self.downsampling_ratio = None
425
+
426
+ self.enable_grad = enable_grad
427
+
428
+ def encode(self, x):
429
+ raise NotImplementedError
430
+
431
+ def decode(self, z):
432
+ raise NotImplementedError
433
+
434
+ def tokenize(self, x):
435
+ raise NotImplementedError
436
+
437
+ def decode_tokens(self, tokens):
438
+ raise NotImplementedError
439
+
440
+
441
+ class StableVAE(LoadPretrainedBase, AutoEncoderBase):
442
+ def __init__(
443
+ self,
444
+ encoder,
445
+ decoder,
446
+ latent_dim,
447
+ downsampling_ratio,
448
+ sample_rate,
449
+ io_channels=2,
450
+ bottleneck: Bottleneck = None,
451
+ pretransform: Pretransform = None,
452
+ in_channels=None,
453
+ out_channels=None,
454
+ soft_clip=False,
455
+ pretrained_ckpt: str | Path = None
456
+ ):
457
+ LoadPretrainedBase.__init__(self)
458
+ AutoEncoderBase.__init__(
459
+ self,
460
+ downsampling_ratio=downsampling_ratio,
461
+ sample_rate=sample_rate,
462
+ latent_shape=(latent_dim, None)
463
+ )
464
+
465
+ self.latent_dim = latent_dim
466
+ self.io_channels = io_channels
467
+ self.in_channels = io_channels
468
+ self.out_channels = io_channels
469
+ self.min_length = self.downsampling_ratio
470
+
471
+ if in_channels is not None:
472
+ self.in_channels = in_channels
473
+
474
+ if out_channels is not None:
475
+ self.out_channels = out_channels
476
+
477
+ self.bottleneck = bottleneck
478
+ self.encoder = encoder
479
+ self.decoder = decoder
480
+ self.pretransform = pretransform
481
+ self.soft_clip = soft_clip
482
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
483
+
484
+ self.remove_autoencoder_prefix_fn: Callable = remove_key_prefix_factory(
485
+ "autoencoder."
486
+ )
487
+ if pretrained_ckpt is not None:
488
+ self.load_pretrained(pretrained_ckpt)
489
+
490
+ def process_state_dict(self, model_dict, state_dict):
491
+ state_dict = state_dict["state_dict"]
492
+ state_dict = self.remove_autoencoder_prefix_fn(model_dict, state_dict)
493
+ return state_dict
494
+
495
+ def encode(
496
+ self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
497
+ ) -> tuple[torch.Tensor, torch.Tensor]:
498
+ z = self.encoder(waveform)
499
+ z = self.bottleneck.encode(z)
500
+ z_length = waveform_lengths // self.downsampling_ratio
501
+ z_mask = create_mask_from_length(z_length)
502
+ return z, z_mask
503
+
504
+ def decode(self, latents: torch.Tensor) -> torch.Tensor:
505
+ waveform = self.decoder(latents)
506
+ return waveform
507
+
508
+
509
+ class StableVAEProjectorWrapper(nn.Module):
510
+ def __init__(
511
+ self,
512
+ vae_dim: int,
513
+ embed_dim: int,
514
+ model: StableVAE | None = None,
515
+ ):
516
+ super().__init__()
517
+ self.model = model
518
+ self.proj = nn.Linear(vae_dim, embed_dim)
519
+
520
+ def forward(
521
+ self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
522
+ ) -> tuple[torch.Tensor, torch.Tensor]:
523
+ self.model.eval()
524
+ with torch.no_grad():
525
+ z, z_mask = self.model.encode(waveform, waveform_lengths)
526
+ z = self.proj(z.transpose(1, 2))
527
+ return {"output": z, "mask": z_mask}
528
+
529
+
530
+ if __name__ == '__main__':
531
+ import hydra
532
+ from utils.config import generate_config_from_command_line_overrides
533
+ model_config = generate_config_from_command_line_overrides(
534
+ "configs/model/autoencoder/stable_vae.yaml"
535
+ )
536
+ autoencoder: StableVAE = hydra.utils.instantiate(model_config)
537
+ autoencoder.eval()
538
+
539
+ waveform, sr = torchaudio.load(
540
+ "/hpc_stor03/sjtu_home/xuenan.xu/data/m4singer/Tenor-1#童话/0006.wav"
541
+ )
542
+ waveform = waveform.mean(0, keepdim=True)
543
+ waveform = torchaudio.functional.resample(
544
+ waveform, sr, model_config["sample_rate"]
545
+ )
546
+ print("waveform: ", waveform.shape)
547
+ with torch.no_grad():
548
+ latent, latent_length = autoencoder.encode(
549
+ waveform, torch.as_tensor([waveform.shape[-1]])
550
+ )
551
+ print("latent: ", latent.shape)
552
+ reconstructed = autoencoder.decode(latent)
553
+ print("reconstructed: ", reconstructed.shape)
554
+ import soundfile as sf
555
+ sf.write(
556
+ "./reconstructed.wav",
557
+ reconstructed[0, 0].numpy(),
558
+ samplerate=model_config["sample_rate"]
559
+ )
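A quick cross-check on the numbers used throughout: with downsampling_ratio 480 at a 24000 Hz sample rate (configs/config.yaml), the StableVAE latent runs at 24000 / 480 = 50 frames per second, so the latent_shape of [128, 500] requested in configs/infer.yaml corresponds to 500 / 50 = 10 seconds of generated audio with 128 latent channels.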
models/common.py ADDED
@@ -0,0 +1,67 @@
+from pathlib import Path
+import torch
+import torch.nn as nn
+from utils.torch_utilities import load_pretrained_model, merge_matched_keys
+
+
+class LoadPretrainedBase(nn.Module):
+    def process_state_dict(
+        self, model_dict: dict[str, torch.Tensor],
+        state_dict: dict[str, torch.Tensor]
+    ):
+        """
+        Custom processing function of each model that transforms `state_dict` loaded
+        from checkpoints into the form expected by `load_state_dict`.
+        By default, `merge_matched_keys` updates parameters whose names and shapes
+        match the current model.
+
+        Args:
+            model_dict:
+                The state dict of the current model, which is going to load pretrained parameters.
+            state_dict:
+                A dictionary of parameters from a pre-trained model.
+
+        Returns:
+            dict[str, torch.Tensor]:
+                The updated state dict, where parameters with matched keys and shapes are
+                updated with values in `state_dict`.
+        """
+        state_dict = merge_matched_keys(model_dict, state_dict)
+        return state_dict
+
+    def load_pretrained(self, ckpt_path: str | Path):
+        load_pretrained_model(
+            self, ckpt_path, state_dict_process_fn=self.process_state_dict
+        )
+
+
+class CountParamsBase(nn.Module):
+    def count_params(self):
+        num_params = 0
+        trainable_params = 0
+        for param in self.parameters():
+            num_params += param.numel()
+            if param.requires_grad:
+                trainable_params += param.numel()
+        return num_params, trainable_params
+
+
+class SaveTrainableParamsBase(nn.Module):
+    @property
+    def param_names_to_save(self):
+        names = []
+        for name, param in self.named_parameters():
+            if param.requires_grad:
+                names.append(name)
+        for name, _ in self.named_buffers():
+            names.append(name)
+        return names
+
+    def load_state_dict(self, state_dict, strict=True):
+        for key in self.param_names_to_save:
+            if key not in state_dict:
+                raise Exception(
+                    f"{key} not found in either pre-trained models (e.g. BERT)"
+                    " or resumed checkpoints (e.g. epoch_40/model.pt)"
+                )
+        return super().load_state_dict(state_dict, strict)
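A sketch of how models hook into LoadPretrainedBase; StableVAE above does exactly this, unwrapping the checkpoint's "state_dict" key and stripping the "autoencoder." prefix. MyModel below is a hypothetical example, not part of the repo, and assumes utils.torch_utilities is importable:

import torch.nn as nn
from models.common import LoadPretrainedBase

class MyModel(LoadPretrainedBase):
    def __init__(self, pretrained_ckpt=None):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        if pretrained_ckpt is not None:
            self.load_pretrained(pretrained_ckpt)  # calls process_state_dict internally

    def process_state_dict(self, model_dict, state_dict):
        # Default behaviour: keep only entries whose names and shapes match this model.
        return super().process_state_dict(model_dict, state_dict)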
models/content_adapter.py ADDED
@@ -0,0 +1,381 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from utils.torch_utilities import concat_non_padding, restore_from_concat
6
+
7
+
8
+ ######################
9
+ # fastspeech modules
10
+ ######################
11
+ class LayerNorm(nn.LayerNorm):
12
+ """Layer normalization module.
13
+ :param int nout: output dim size
14
+ :param int dim: dimension to be normalized
15
+ """
16
+ def __init__(self, nout, dim=-1):
17
+ """Construct an LayerNorm object."""
18
+ super(LayerNorm, self).__init__(nout, eps=1e-12)
19
+ self.dim = dim
20
+
21
+ def forward(self, x):
22
+ """Apply layer normalization.
23
+ :param torch.Tensor x: input tensor
24
+ :return: layer normalized tensor
25
+ :rtype torch.Tensor
26
+ """
27
+ if self.dim == -1:
28
+ return super(LayerNorm, self).forward(x)
29
+ return super(LayerNorm,
30
+ self).forward(x.transpose(1, -1)).transpose(1, -1)
31
+
32
+
33
+ class DurationPredictor(nn.Module):
34
+ def __init__(
35
+ self,
36
+ in_channels: int,
37
+ filter_channels: int,
38
+ n_layers: int = 2,
39
+ kernel_size: int = 3,
40
+ p_dropout: float = 0.1,
41
+ padding: str = "SAME"
42
+ ):
43
+ super(DurationPredictor, self).__init__()
44
+ self.conv = nn.ModuleList()
45
+ self.kernel_size = kernel_size
46
+ self.padding = padding
47
+ for idx in range(n_layers):
48
+ in_chans = in_channels if idx == 0 else filter_channels
49
+ self.conv += [
50
+ nn.Sequential(
51
+ nn.ConstantPad1d(((kernel_size - 1) // 2,
52
+ (kernel_size - 1) //
53
+ 2) if padding == 'SAME' else
54
+ (kernel_size - 1, 0), 0),
55
+ nn.Conv1d(
56
+ in_chans,
57
+ filter_channels,
58
+ kernel_size,
59
+ stride=1,
60
+ padding=0
61
+ ), nn.ReLU(), LayerNorm(filter_channels, dim=1),
62
+ nn.Dropout(p_dropout)
63
+ )
64
+ ]
65
+ self.linear = nn.Linear(filter_channels, 1)
66
+
67
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
68
+ # x: [B, T, E]
69
+ x = x.transpose(1, -1)
70
+ x_mask = x_mask.unsqueeze(1).to(x.device)
71
+ for f in self.conv:
72
+ x = f(x)
73
+ x = x * x_mask.float()
74
+
75
+ x = self.linear(x.transpose(1, -1)
76
+ ) * x_mask.transpose(1, -1).float() # [B, T, 1]
77
+ return x
78
+
79
+
80
+ ######################
81
+ # adapter modules
82
+ ######################
83
+
84
+
85
+ class ContentAdapterBase(nn.Module):
86
+ def __init__(self, d_out):
87
+ super().__init__()
88
+ self.d_out = d_out
89
+
90
+
91
+ class SinusoidalPositionalEmbedding(nn.Module):
92
+ def __init__(self, d_model, dropout, max_len=1000):
93
+ super().__init__()
94
+ self.dropout = nn.Dropout(dropout)
95
+ pe = torch.zeros(max_len, d_model)
96
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
97
+ div_term = torch.exp(
98
+ torch.arange(0, d_model, 2).float() *
99
+ (-math.log(10000.0) / d_model)
100
+ )
101
+ pe[:, 0::2] = torch.sin(position * div_term)
102
+ pe[:, 1::2] = torch.cos(position * div_term)
103
+ pe = pe.unsqueeze(0).transpose(0, 1)
104
+ self.register_buffer('pe', pe)
105
+
106
+ def forward(self, x):
107
+ x = x + self.pe[:x.size(1), :]
108
+ return self.dropout(x)
109
+
110
+
111
+ class ContentAdapter(ContentAdapterBase):
112
+ def __init__(
113
+ self,
114
+ d_model: int,
115
+ d_out: int,
116
+ num_layers: int,
117
+ num_heads: int,
118
+ duration_predictor: DurationPredictor,
119
+ dropout: float = 0.1,
120
+ norm_first: bool = False,
121
+ activation: str = "gelu",
122
+ duration_grad_scale: float = 0.0,
123
+ ):
124
+ super().__init__(d_out)
125
+ self.duration_grad_scale = duration_grad_scale
126
+ self.cls_embed = nn.Parameter(torch.randn(d_model))
127
+ if hasattr(torch, "npu") and torch.npu.is_available():
128
+ enable_nested_tensor = False
129
+ else:
130
+ enable_nested_tensor = True
131
+ encoder_layer = nn.TransformerEncoderLayer(
132
+ d_model=d_model,
133
+ nhead=num_heads,
134
+ dim_feedforward=4 * d_model,
135
+ dropout=dropout,
136
+ activation=activation,
137
+ norm_first=norm_first,
138
+ batch_first=True
139
+ )
140
+ self.encoder_layers = nn.TransformerEncoder(
141
+ encoder_layer=encoder_layer,
142
+ num_layers=num_layers,
143
+ enable_nested_tensor=enable_nested_tensor
144
+ )
145
+ self.duration_predictor = duration_predictor
146
+ self.content_proj = nn.Conv1d(d_model, d_out, 1)
147
+
148
+ def forward(self, x, x_mask):
149
+ batch_size = x.size(0)
150
+ cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1)
151
+ cls_embed = cls_embed.to(x.device).unsqueeze(1)
152
+ x = torch.cat([cls_embed, x], dim=1)
153
+
154
+ cls_mask = torch.ones(batch_size, 1).to(x_mask.device)
155
+ x_mask = torch.cat([cls_mask, x_mask], dim=1)
156
+ x = self.encoder_layers(x, src_key_padding_mask=~x_mask.bool())
157
+ x_grad_rescaled = x * self.duration_grad_scale + x.detach(
158
+ ) * (1 - self.duration_grad_scale)
159
+ duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1)
160
+ content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
161
+ return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]
162
+
163
+
164
+ class PrefixAdapter(ContentAdapterBase):
165
+ def __init__(
166
+ self,
167
+ content_dim: int,
168
+ d_model: int,
169
+ d_out: int,
170
+ prefix_dim: int,
171
+ num_layers: int,
172
+ num_heads: int,
173
+ duration_predictor: DurationPredictor,
174
+ dropout: float = 0.1,
175
+ norm_first: bool = False,
176
+ use_last_norm: bool = True,
177
+ activation: str = "gelu",
178
+ duration_grad_scale: float = 0.1,
179
+ ):
180
+ super().__init__(d_out)
181
+ self.duration_grad_scale = duration_grad_scale
182
+ self.prefix_mlp = nn.Sequential(
183
+ nn.Linear(prefix_dim, d_model), nn.ReLU(), nn.Dropout(dropout),
184
+ nn.Linear(d_model, d_model)
185
+ )
186
+ self.content_mlp = nn.Sequential(
187
+ nn.Linear(content_dim, d_model), nn.ReLU(), nn.Dropout(dropout),
188
+ nn.Linear(d_model, d_model)
189
+ )
190
+ layer = nn.TransformerEncoderLayer(
191
+ d_model=d_model,
192
+ nhead=num_heads,
193
+ dim_feedforward=4 * d_model,
194
+ dropout=dropout,
195
+ activation=activation,
196
+ batch_first=True,
197
+ norm_first=norm_first
198
+ )
199
+ if hasattr(torch, "npu") and torch.npu.is_available():
200
+ enable_nested_tensor = False
201
+ else:
202
+ enable_nested_tensor = True
203
+ self.cls_embed = nn.Parameter(torch.randn(d_model))
204
+ # self.pos_embed = SinusoidalPositionalEmbedding(d_model, dropout)
205
+ self.layers = nn.TransformerEncoder(
206
+ encoder_layer=layer,
207
+ num_layers=num_layers,
208
+ enable_nested_tensor=enable_nested_tensor
209
+ )
210
+ self.use_last_norm = use_last_norm
211
+ if self.use_last_norm:
212
+ self.last_norm = nn.LayerNorm(d_model)
213
+ self.duration_predictor = duration_predictor
214
+ self.content_proj = nn.Conv1d(d_model, d_out, 1)
215
+ nn.init.normal_(self.cls_embed, 0., 0.02)
216
+ nn.init.xavier_uniform_(self.content_proj.weight)
217
+ nn.init.constant_(self.content_proj.bias, 0.)
218
+
219
+ def forward(self, content, content_mask, instruction, instruction_mask):
220
+ batch_size = content.size(0)
221
+ cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1)
222
+ cls_embed = cls_embed.to(content.device).unsqueeze(1)
223
+ content = self.content_mlp(content)
224
+ x = torch.cat([cls_embed, content], dim=1)
225
+ cls_mask = torch.ones(batch_size, 1,
226
+ dtype=bool).to(content_mask.device)
227
+ x_mask = torch.cat([cls_mask, content_mask], dim=1)
228
+
229
+ prefix = self.prefix_mlp(instruction)
230
+ seq, seq_mask, perm = concat_non_padding(
231
+ prefix, instruction_mask, x, x_mask
232
+ )
233
+ # seq = self.pos_embed(seq)
234
+ x = self.layers(seq, src_key_padding_mask=~seq_mask.bool())
235
+ if self.use_last_norm:
236
+ x = self.last_norm(x)
237
+ _, x = restore_from_concat(x, instruction_mask, x_mask, perm)
238
+
239
+ x_grad_rescaled = x * self.duration_grad_scale + x.detach(
240
+ ) * (1 - self.duration_grad_scale)
241
+ duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1)
242
+ content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
243
+ return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]
244
+
245
+
246
+ class CrossAttentionAdapter(ContentAdapterBase):
247
+ def __init__(
248
+ self,
249
+ d_out: int,
250
+ content_dim: int,
251
+ prefix_dim: int,
252
+ num_heads: int,
253
+ duration_predictor: DurationPredictor,
254
+ dropout: float = 0.1,
255
+ duration_grad_scale: float = 0.1,
256
+ ):
257
+ super().__init__(d_out)
258
+ self.attn = nn.MultiheadAttention(
259
+ embed_dim=content_dim,
260
+ num_heads=num_heads,
261
+ dropout=dropout,
262
+ kdim=prefix_dim,
263
+ vdim=prefix_dim,
264
+ batch_first=True,
265
+ )
266
+ self.duration_grad_scale = duration_grad_scale
267
+ self.duration_predictor = duration_predictor
268
+ self.global_duration_mlp = nn.Sequential(
269
+ nn.Linear(content_dim, content_dim), nn.ReLU(),
270
+ nn.Dropout(dropout), nn.Linear(content_dim, 1)
271
+ )
272
+ self.norm = nn.LayerNorm(content_dim)
273
+ self.content_proj = nn.Conv1d(content_dim, d_out, 1)
274
+
275
+ def forward(self, content, content_mask, prefix, prefix_mask):
276
+ attn_output, attn_output_weights = self.attn(
277
+ query=content,
278
+ key=prefix,
279
+ value=prefix,
280
+ key_padding_mask=~prefix_mask.bool()
281
+ )
282
+ attn_output = attn_output * content_mask.unsqueeze(-1).float()
283
+ x = self.norm(attn_output + content)
284
+ x_grad_rescaled = x * self.duration_grad_scale + x.detach(
285
+ ) * (1 - self.duration_grad_scale)
286
+ x_aggregated = (x_grad_rescaled * content_mask.unsqueeze(-1).float()
287
+ ).sum(dim=1) / content_mask.sum(dim=1,
288
+ keepdim=True).float()
289
+ global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
290
+ local_duration = self.duration_predictor(
291
+ x_grad_rescaled, content_mask
292
+ ).squeeze(-1)
293
+ content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
294
+ return content, content_mask, global_duration, local_duration
295
+
296
+
297
+ class ExperimentalCrossAttentionAdapter(ContentAdapterBase):
298
+ def __init__(
299
+ self,
300
+ d_out: int,
301
+ content_dim: int,
302
+ prefix_dim: int,
303
+ num_heads: int,
304
+ duration_predictor: DurationPredictor,
305
+ dropout: float = 0.1,
306
+ duration_grad_scale: float = 0.1,
307
+ ):
308
+ super().__init__(d_out)
309
+ self.content_mlp = nn.Sequential(
310
+ nn.Linear(content_dim, content_dim),
311
+ nn.ReLU(),
312
+ nn.Dropout(dropout),
313
+ nn.Linear(content_dim, content_dim),
314
+ )
315
+ self.content_norm = nn.LayerNorm(content_dim)
316
+ self.prefix_mlp = nn.Sequential(
317
+ nn.Linear(prefix_dim, prefix_dim),
318
+ nn.ReLU(),
319
+ nn.Dropout(dropout),
320
+ nn.Linear(prefix_dim, prefix_dim),
321
+ )
322
+ self.prefix_norm = nn.LayerNorm(content_dim)
323
+ self.attn = nn.MultiheadAttention(
324
+ embed_dim=content_dim,
325
+ num_heads=num_heads,
326
+ dropout=dropout,
327
+ kdim=prefix_dim,
328
+ vdim=prefix_dim,
329
+ batch_first=True,
330
+ )
331
+ self.duration_grad_scale = duration_grad_scale
332
+ self.duration_predictor = duration_predictor
333
+ self.global_duration_mlp = nn.Sequential(
334
+ nn.Linear(content_dim, content_dim), nn.ReLU(),
335
+ nn.Dropout(dropout), nn.Linear(content_dim, 1)
336
+ )
337
+ self.content_proj = nn.Sequential(
338
+ nn.Linear(content_dim, d_out),
339
+ nn.ReLU(),
340
+ nn.Dropout(dropout),
341
+ nn.Linear(d_out, d_out),
342
+ )
343
+ self.norm1 = nn.LayerNorm(content_dim)
344
+ self.norm2 = nn.LayerNorm(d_out)
345
+ self.init_weights()
346
+
347
+ def init_weights(self):
348
+ def _init_weights(module):
349
+ if isinstance(module, nn.Linear):
350
+ nn.init.xavier_uniform_(module.weight)
351
+ if module.bias is not None:
352
+ nn.init.constant_(module.bias, 0.)
353
+
354
+ self.apply(_init_weights)
355
+
356
+ def forward(self, content, content_mask, prefix, prefix_mask):
357
+ content = self.content_mlp(content)
358
+ content = self.content_norm(content)
359
+ prefix = self.prefix_mlp(prefix)
360
+ prefix = self.prefix_norm(prefix)
361
+ attn_output, attn_weights = self.attn(
362
+ query=content,
363
+ key=prefix,
364
+ value=prefix,
365
+ key_padding_mask=~prefix_mask.bool(),
366
+ )
367
+ attn_output = attn_output * content_mask.unsqueeze(-1).float()
368
+ x = attn_output + content
369
+ x = self.norm1(x)
370
+ x_grad_rescaled = x * self.duration_grad_scale + x.detach(
371
+ ) * (1 - self.duration_grad_scale)
372
+ x_aggregated = (x_grad_rescaled * content_mask.unsqueeze(-1).float()
373
+ ).sum(dim=1) / content_mask.sum(dim=1,
374
+ keepdim=True).float()
375
+ global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
376
+ local_duration = self.duration_predictor(
377
+ x_grad_rescaled, content_mask
378
+ ).squeeze(-1)
379
+ content = self.content_proj(x)
380
+ content = self.norm2(content)
381
+ return content, content_mask, global_duration, local_duration
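A forward-pass sketch of the adapter stack defined above, wiring DurationPredictor into CrossAttentionAdapter on random tensors; the dimensions are illustrative rather than the training configuration, and the import assumes the repo's utils package is on the path:

import torch
from models.content_adapter import DurationPredictor, CrossAttentionAdapter

dur = DurationPredictor(in_channels=256, filter_channels=256)
adapter = CrossAttentionAdapter(
    d_out=1024, content_dim=256, prefix_dim=128, num_heads=4, duration_predictor=dur
)

content = torch.randn(2, 50, 256)                   # [B, T, content_dim]
content_mask = torch.ones(2, 50, dtype=torch.bool)
prefix = torch.randn(2, 10, 128)                    # [B, S, prefix_dim] conditioning tokens
prefix_mask = torch.ones(2, 10, dtype=torch.bool)

out, out_mask, global_dur, local_dur = adapter(content, content_mask, prefix, prefix_mask)
print(out.shape, global_dur.shape, local_dur.shape)  # [2, 50, 1024], [2], [2, 50]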
models/content_encoder/__pycache__/content_encoder.cpython-310.pyc ADDED
Binary file (5.51 kB). View file
 
models/content_encoder/__pycache__/sketch_encoder.cpython-310.pyc ADDED
Binary file (1.93 kB). View file
 
models/content_encoder/__pycache__/text_encoder.cpython-310.pyc ADDED
Binary file (3.07 kB). View file
 
models/content_encoder/content_encoder.py ADDED
@@ -0,0 +1,280 @@
1
+ from typing import Any
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class ContentEncoder(nn.Module):
7
+ def __init__(
8
+ self,
9
+ embed_dim: int,
10
+ text_encoder: nn.Module = None,
11
+ video_encoder: nn.Module = None,
12
+ midi_encoder: nn.Module = None,
13
+ phoneme_encoder: nn.Module = None,
14
+ pitch_encoder: nn.Module = None,
15
+ audio_encoder: nn.Module = None,
16
+ speech_encoder: nn.Module = None,
17
+ sketch_encoder: nn.Module = None,
18
+ ):
19
+ super().__init__()
20
+ self.embed_dim = embed_dim
21
+ self.text_encoder = text_encoder
22
+ self.midi_encoder = midi_encoder
23
+ self.phoneme_encoder = phoneme_encoder
24
+ self.pitch_encoder = pitch_encoder
25
+ self.audio_encoder = audio_encoder
26
+ self.video_encoder = video_encoder
27
+ self.speech_encoder = speech_encoder
28
+ self.sketch_encoder = sketch_encoder
29
+
30
+ def encode_content(
31
+ self, batch_content: list[Any], batch_task: list[str],
32
+ device: str | torch.device
33
+ ):
34
+ batch_content_output = []
35
+ batch_content_mask = []
36
+ batch_la_content_output = []
37
+
38
+ zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)
39
+
40
+ for content, task in zip(batch_content, batch_task):
41
+ if task == "audio_super_resolution" or task == "speech_enhancement":
42
+ content_dict = {
43
+ "waveform": torch.as_tensor(content).float(),
44
+ "waveform_lengths": torch.as_tensor(content.shape[0]),
45
+ }
46
+ for key in list(content_dict.keys()):
47
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
48
+ device
49
+ )
50
+ content_output_dict = self.audio_encoder(**content_dict)
51
+ la_content_output_dict = {
52
+ "output": zero_la_content,
53
+ }
54
+ elif task == "text_to_audio" or task == "text_to_music":
55
+ content_output_dict = self.text_encoder([content])
56
+ la_content_output_dict = {
57
+ "output": zero_la_content,
58
+ }
59
+ elif task == "speech_to_audio":
60
+ input_dict = {
61
+ "embed": content,
62
+ "embed_len": torch.tensor([content.shape[1]], dtype=torch.int).to(device),
63
+ }
64
+ content_output_dict = self.speech_encoder(input_dict)
65
+ la_content_output_dict = {
66
+ "output": zero_la_content,
67
+ }
68
+ elif task == "direct_speech_to_audio":
69
+ # content shape [1, L/T 133, dim] mask [1, L/T 133] in hubert
70
+ if len(content.shape) < 3:
71
+ content = content.unsqueeze(0)
72
+ mask = torch.ones(content.shape[:2])
73
+ mask = (mask == 1).to(content.device)
74
+ content_output_dict = {
75
+ "output": content,
76
+ "mask": mask,
77
+ }
78
+ la_content_output_dict = {
79
+ "output": zero_la_content,
80
+ }
81
+ elif task == "sketch_to_audio":
82
+ content_output_dict = self.sketch_encoder([content["caption"]])
83
+ content_dict = {
84
+ "f0": torch.as_tensor(content["f0"]),
85
+ "energy": torch.as_tensor(content["energy"]),
86
+ }
87
+ for key in list(content_dict.keys()):
88
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
89
+ device
90
+ )
91
+ la_content_output_dict = self.sketch_encoder.encode_sketch(
92
+ **content_dict
93
+ )
94
+ elif task == "video_to_audio":
95
+ content_dict = {
96
+ "frames": torch.as_tensor(content).float(),
97
+ "frame_nums": torch.as_tensor(content.shape[0]),
98
+ }
99
+ for key in list(content_dict.keys()):
100
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
101
+ device
102
+ )
103
+ content_output_dict = self.video_encoder(**content_dict)
104
+ la_content_output_dict = {
105
+ "output": zero_la_content,
106
+ }
107
+ elif task == "singing_voice_synthesis":
108
+ content_dict = {
109
+ "phoneme":
110
+ torch.as_tensor(content["phoneme"]).long(),
111
+ "midi":
112
+ torch.as_tensor(content["midi"]).long(),
113
+ "midi_duration":
114
+ torch.as_tensor(content["midi_duration"]).float(),
115
+ "is_slur":
116
+ torch.as_tensor(content["is_slur"]).long()
117
+ }
118
+ if "spk" in content:
119
+ if self.midi_encoder.spk_config.encoding_format == "id":
120
+ content_dict["spk"] = torch.as_tensor(content["spk"]
121
+ ).long()
122
+ elif self.midi_encoder.spk_config.encoding_format == "embedding":
123
+ content_dict["spk"] = torch.as_tensor(content["spk"]
124
+ ).float()
125
+ for key in list(content_dict.keys()):
126
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
127
+ device
128
+ )
129
+ content_dict["lengths"] = torch.as_tensor([
130
+ len(content["phoneme"])
131
+ ])
132
+ content_output_dict = self.midi_encoder(**content_dict)
133
+ la_content_output_dict = {"output": zero_la_content}
134
+ elif task == "text_to_speech":
135
+ content_dict = {
136
+ "phoneme": torch.as_tensor(content["phoneme"]).long(),
137
+ }
138
+ if "spk" in content:
139
+ if self.phoneme_encoder.spk_config.encoding_format == "id":
140
+ content_dict["spk"] = torch.as_tensor(content["spk"]
141
+ ).long()
142
+ elif self.phoneme_encoder.spk_config.encoding_format == "embedding":
143
+ content_dict["spk"] = torch.as_tensor(content["spk"]
144
+ ).float()
145
+ for key in list(content_dict.keys()):
146
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
147
+ device
148
+ )
149
+ content_dict["lengths"] = torch.as_tensor([
150
+ len(content["phoneme"])
151
+ ])
152
+ content_output_dict = self.phoneme_encoder(**content_dict)
153
+ la_content_output_dict = {"output": zero_la_content}
154
+ elif task == "singing_acoustic_modeling":
155
+ content_dict = {
156
+ "phoneme": torch.as_tensor(content["phoneme"]).long(),
157
+ }
158
+ for key in list(content_dict.keys()):
159
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
160
+ device
161
+ )
162
+ content_dict["lengths"] = torch.as_tensor([
163
+ len(content["phoneme"])
164
+ ])
165
+ content_output_dict = self.pitch_encoder(**content_dict)
166
+
167
+ content_dict = {
168
+ "f0": torch.as_tensor(content["f0"]),
169
+ "uv": torch.as_tensor(content["uv"]),
170
+ }
171
+ for key in list(content_dict.keys()):
172
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
173
+ device
174
+ )
175
+ la_content_output_dict = self.pitch_encoder.encode_pitch(
176
+ **content_dict
177
+ )
178
+
179
+ batch_content_output.append(content_output_dict["output"][0])
180
+ batch_content_mask.append(content_output_dict["mask"][0])
181
+ batch_la_content_output.append(la_content_output_dict["output"][0])
182
+
183
+ batch_content_output = nn.utils.rnn.pad_sequence(
184
+ batch_content_output, batch_first=True, padding_value=0
185
+ )
186
+ batch_content_mask = nn.utils.rnn.pad_sequence(
187
+ batch_content_mask, batch_first=True, padding_value=False
188
+ )
189
+ batch_la_content_output = nn.utils.rnn.pad_sequence(
190
+ batch_la_content_output, batch_first=True, padding_value=0
191
+ )
192
+ return {
193
+ "content": batch_content_output,
194
+ "content_mask": batch_content_mask,
195
+ "length_aligned_content": batch_la_content_output,
196
+ }
197
+
198
+
199
+ class BatchedContentEncoder(ContentEncoder):
200
+ def encode_content(
201
+ self, batch_content: list | dict, batch_task: list[str],
202
+ device: str | torch.device
203
+ ):
204
+ task = batch_task[0]
205
+ zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)
206
+ if task == "audio_super_resolution" or task == "speech_enhancement":
207
+ content_dict = {
208
+ "waveform":
209
+ batch_content["content"].unsqueeze(1).float().to(device),
210
+ "waveform_lengths":
211
+ batch_content["content_lengths"].long().to(device),
212
+ }
213
+ content_output = self.audio_encoder(**content_dict)
214
+ la_content_output = zero_la_content
215
+ elif task == "text_to_audio":
216
+ content_output = self.text_encoder(batch_content)
217
+ la_content_output = zero_la_content
218
+ elif task == "video_to_audio":
219
+ content_dict = {
220
+ "frames":
221
+ batch_content["content"].float().to(device),
222
+ "frame_nums":
223
+ batch_content["content_lengths"].long().to(device),
224
+ }
225
+ content_output = self.video_encoder(**content_dict)
226
+ la_content_output = zero_la_content
227
+ elif task == "singing_voice_synthesis":
228
+ content_dict = {
229
+ "phoneme":
230
+ batch_content["phoneme"].long().to(device),
231
+ "midi":
232
+ batch_content["midi"].long().to(device),
233
+ "midi_duration":
234
+ batch_content["midi_duration"].float().to(device),
235
+ "is_slur":
236
+ batch_content["is_slur"].long().to(device),
237
+ "lengths":
238
+ batch_content["phoneme_lengths"].long().cpu(),
239
+ }
240
+ if "spk" in batch_content:
241
+ if self.midi_encoder.spk_config.encoding_format == "id":
242
+ content_dict["spk"] = batch_content["spk"].long(
243
+ ).to(device)
244
+ elif self.midi_encoder.spk_config.encoding_format == "embedding":
245
+ content_dict["spk"] = batch_content["spk"].float(
246
+ ).to(device)
247
+ content_output = self.midi_encoder(**content_dict)
248
+ la_content_output = zero_la_content
249
+ elif task == "text_to_speech":
250
+ content_dict = {
251
+ "phoneme": batch_content["phoneme"].long().to(device),
252
+ "lengths": batch_content["phoneme_lengths"].long().cpu(),
253
+ }
254
+ if "spk" in batch_content:
255
+ if self.phoneme_encoder.spk_config.encoding_format == "id":
256
+ content_dict["spk"] = batch_content["spk"].long(
257
+ ).to(device)
258
+ elif self.phoneme_encoder.spk_config.encoding_format == "embedding":
259
+ content_dict["spk"] = batch_content["spk"].float(
260
+ ).to(device)
261
+ content_output = self.phoneme_encoder(**content_dict)
262
+ la_content_output = zero_la_content
263
+ elif task == "singing_acoustic_modeling":
264
+ content_dict = {
265
+ "phoneme": batch_content["phoneme"].long().to(device),
266
+ "lengths": batch_content["phoneme_lengths"].long().to(device),
267
+ }
268
+ content_output = self.pitch_encoder(**content_dict)
269
+
270
+ content_dict = {
271
+ "f0": batch_content["f0"].float().to(device),
272
+ "uv": batch_content["uv"].float().to(device),
273
+ }
274
+ la_content_output = self.pitch_encoder.encode_pitch(**content_dict)
275
+
276
+ return {
277
+ "content": content_output["output"],
278
+ "content_mask": content_output["mask"],
279
+ "length_aligned_content": la_content_output,
280
+ }
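A sketch of the speech_to_audio branch used by this Space. In configs/config.yaml the speech_encoder is QformerBridgeNet and every other encoder stays None; DummySpeechEncoder below is a hypothetical stand-in with the same dict-in / dict-out contract, used only to show the data flow (run from the repo root):

import torch
import torch.nn as nn
from models.content_encoder.content_encoder import ContentEncoder

class DummySpeechEncoder(nn.Module):
    """Stand-in for QformerBridgeNet: takes {"embed", "embed_len"}, returns {"output", "mask"}."""
    def forward(self, inputs):
        embed = inputs["embed"]  # [1, T, D] HuBERT-like features
        return {"output": embed, "mask": torch.ones(embed.shape[:2], dtype=torch.bool)}

encoder = ContentEncoder(embed_dim=1024, speech_encoder=DummySpeechEncoder())
features = torch.randn(1, 133, 1024)  # stands in for the HuBERT features extracted in app.py
out = encoder.encode_content([features], ["speech_to_audio"], device="cpu")
print(out["content"].shape, out["content_mask"].shape)  # [1, 133, 1024], [1, 133]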
models/content_encoder/midi_encoder.py ADDED
@@ -0,0 +1,1046 @@
1
+ from typing import Sequence
2
+ from dataclasses import dataclass
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.nn import Parameter
8
+
9
+ from utils.torch_utilities import create_mask_from_length
10
+ from utils.diffsinger_utilities import denorm_f0, f0_to_coarse
11
+
12
+
13
+ def make_positions(tensor, padding_idx):
14
+ """Replace non-padding symbols with their position numbers.
15
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
16
+ """
17
+ # The series of casts and type-conversions here are carefully
18
+ # balanced to both work with ONNX export and XLA. In particular XLA
19
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
20
+ # how to handle the dtype kwarg in cumsum.
21
+ mask = tensor.ne(padding_idx).int()
22
+ return (torch.cumsum(mask, dim=1).type_as(mask) *
23
+ mask).long() + padding_idx
24
+
25
+
26
+ def softmax(x, dim):
27
+ return F.softmax(x, dim=dim, dtype=torch.float32)
28
+
29
+
30
+ def LayerNorm(
31
+ normalized_shape, eps=1e-5, elementwise_affine=True, export=False
32
+ ):
33
+ if not export and torch.cuda.is_available():
34
+ try:
35
+ from apex.normalization import FusedLayerNorm
36
+ return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
37
+ except ImportError:
38
+ pass
39
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
40
+
41
+
42
+ def Linear(in_features, out_features, bias=True):
43
+ m = nn.Linear(in_features, out_features, bias)
44
+ nn.init.xavier_uniform_(m.weight)
45
+ if bias:
46
+ nn.init.constant_(m.bias, 0.)
47
+ return m
48
+
49
+
50
+ def Embedding(num_embeddings, embedding_dim, padding_idx=None):
51
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
52
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
53
+ if padding_idx is not None:
54
+ nn.init.constant_(m.weight[padding_idx], 0)
55
+ return m
56
+
57
+
58
+ class BatchNorm1dTBC(nn.Module):
59
+ def __init__(self, c):
60
+ super(BatchNorm1dTBC, self).__init__()
61
+ self.bn = nn.BatchNorm1d(c)
62
+
63
+ def forward(self, x):
64
+ """
65
+
66
+ :param x: [T, B, C]
67
+ :return: [T, B, C]
68
+ """
69
+ x = x.permute(1, 2, 0) # [B, C, T]
70
+ x = self.bn(x) # [B, C, T]
71
+ x = x.permute(2, 0, 1) # [T, B, C]
72
+ return x
73
+
74
+
75
+ class PositionalEncoding(nn.Module):
76
+ """Positional encoding.
77
+ Args:
78
+ d_model (int): Embedding dimension.
79
+ dropout_rate (float): Dropout rate.
80
+ max_len (int): Maximum input length.
81
+ reverse (bool): Whether to reverse the input position.
82
+ """
83
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
84
+ """Construct an PositionalEncoding object."""
85
+ super(PositionalEncoding, self).__init__()
86
+ self.d_model = d_model
87
+ self.reverse = reverse
88
+ self.xscale = math.sqrt(self.d_model)
89
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
90
+ self.pe = None
91
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
92
+
93
+ def extend_pe(self, x):
94
+ """Reset the positional encodings."""
95
+ if self.pe is not None:
96
+ if self.pe.size(1) >= x.size(1):
97
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
98
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
99
+ return
100
+ pe = torch.zeros(x.size(1), self.d_model)
101
+ if self.reverse:
102
+ position = torch.arange(
103
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
104
+ ).unsqueeze(1)
105
+ else:
106
+ position = torch.arange(0, x.size(1),
107
+ dtype=torch.float32).unsqueeze(1)
108
+ div_term = torch.exp(
109
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
110
+ -(math.log(10000.0) / self.d_model)
111
+ )
112
+ pe[:, 0::2] = torch.sin(position * div_term)
113
+ pe[:, 1::2] = torch.cos(position * div_term)
114
+ pe = pe.unsqueeze(0)
115
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
116
+
117
+ def forward(self, x: torch.Tensor):
118
+ """Add positional encoding.
119
+ Args:
120
+ x (torch.Tensor): Input tensor (batch, time, `*`).
121
+ Returns:
122
+ torch.Tensor: Encoded tensor (batch, time, `*`).
123
+ """
124
+ self.extend_pe(x)
125
+ x = x * self.xscale + self.pe[:, :x.size(1)]
126
+ return self.dropout(x)
127
+
128
+
129
+ class SinusoidalPositionalEmbedding(nn.Module):
130
+ """This module produces sinusoidal positional embeddings of any length.
131
+
132
+ Padding symbols are ignored.
133
+ """
134
+ def __init__(self, d_model, padding_idx, init_size=2048):
135
+ super().__init__()
136
+ self.d_model = d_model
137
+ self.padding_idx = padding_idx
138
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
139
+ init_size,
140
+ d_model,
141
+ padding_idx,
142
+ )
143
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
144
+
145
+ @staticmethod
146
+ def get_embedding(num_embeddings, d_model, padding_idx=None):
147
+ """Build sinusoidal embeddings.
148
+
149
+ This matches the implementation in tensor2tensor, but differs slightly
150
+ from the description in Section 3.5 of "Attention Is All You Need".
151
+ """
152
+ half_dim = d_model // 2
153
+ emb = math.log(10000) / (half_dim - 1)
154
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
155
+ emb = torch.arange(num_embeddings,
156
+ dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
157
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)],
158
+ dim=1).view(num_embeddings, -1)
159
+ if d_model % 2 == 1:
160
+ # zero pad
161
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
162
+ if padding_idx is not None:
163
+ emb[padding_idx, :] = 0
164
+ return emb
165
+
166
+ def forward(
167
+ self,
168
+ x,
169
+ lengths,
170
+ incremental_state=None,
171
+ timestep=None,
172
+ positions=None,
173
+ **kwargs
174
+ ):
175
+ """Input is expected to be of size [bsz x seqlen]."""
176
+ bsz, seq_len = x.shape[:2]
177
+ max_pos = self.padding_idx + 1 + seq_len
178
+ if self.weights is None or max_pos > self.weights.size(0):
179
+ # recompute/expand embeddings if needed
180
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
181
+ max_pos,
182
+ self.d_model,
183
+ self.padding_idx,
184
+ )
185
+ self.weights = self.weights.to(self._float_tensor)
186
+
187
+ if incremental_state is not None:
188
+ # positions is the same for every token when decoding a single step
189
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
190
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
191
+
192
+ positions = create_mask_from_length(
193
+ lengths, max_length=x.shape[1]
194
+ ) * (torch.arange(x.shape[1]) + 1).unsqueeze(0).expand(x.shape[0], -1)
195
+ positions = positions.to(self.weights.device)
196
+ pos_emb = self.weights.index_select(0, positions.view(-1)).view(
197
+ bsz, seq_len, -1
198
+ ).detach()
199
+ return x + pos_emb
200
+
201
+ def max_positions(self):
202
+ """Maximum number of supported positions."""
203
+ return int(1e5)  # an arbitrarily large number
204
+
205
+
206
+ class RelPositionalEncoding(PositionalEncoding):
207
+ """Relative positional encoding module.
208
+ See : Appendix B in https://arxiv.org/abs/1901.02860
209
+ Args:
210
+ d_model (int): Embedding dimension.
211
+ dropout_rate (float): Dropout rate.
212
+ max_len (int): Maximum input length.
213
+ """
214
+ def __init__(self, d_model, dropout_rate, max_len=5000):
215
+ """Initialize class."""
216
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
217
+
218
+ def forward(self, x, lengths):
219
+ """Compute positional encoding.
220
+ Args:
221
+ x (torch.Tensor): Input tensor (batch, time, `*`).
222
+ Returns:
223
+ torch.Tensor: Encoded tensor (batch, time, `*`).
224
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
225
+ """
226
+ self.extend_pe(x)
227
+ x = x * self.xscale
228
+ pos_emb = self.pe[:, :x.size(1)]
229
+ return self.dropout(x) + self.dropout(pos_emb)
230
+
231
+
232
+ class MultiheadAttention(nn.Module):
233
+ def __init__(
234
+ self,
235
+ embed_dim,
236
+ num_heads,
237
+ kdim=None,
238
+ vdim=None,
239
+ dropout=0.,
240
+ bias=True,
241
+ add_bias_kv=False,
242
+ add_zero_attn=False,
243
+ self_attention=False,
244
+ encoder_decoder_attention=False
245
+ ):
246
+ super().__init__()
247
+ self.embed_dim = embed_dim
248
+ self.kdim = kdim if kdim is not None else embed_dim
249
+ self.vdim = vdim if vdim is not None else embed_dim
250
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
251
+
252
+ self.num_heads = num_heads
253
+ self.dropout = dropout
254
+ self.head_dim = embed_dim // num_heads
255
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
256
+ self.scaling = self.head_dim**-0.5
257
+
258
+ self.self_attention = self_attention
259
+ self.encoder_decoder_attention = encoder_decoder_attention
260
+
261
+ assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
262
+ 'value to be of the same size'
263
+
264
+ if self.qkv_same_dim:
265
+ self.in_proj_weight = Parameter(
266
+ torch.Tensor(3 * embed_dim, embed_dim)
267
+ )
268
+ else:
269
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
270
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
271
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
272
+
273
+ if bias:
274
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
275
+ else:
276
+ self.register_parameter('in_proj_bias', None)
277
+
278
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
279
+
280
+ if add_bias_kv:
281
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
282
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
283
+ else:
284
+ self.bias_k = self.bias_v = None
285
+
286
+ self.add_zero_attn = add_zero_attn
287
+
288
+ self.reset_parameters()
289
+
290
+ self.enable_torch_version = False
291
+ if hasattr(F, "multi_head_attention_forward"):
292
+ self.enable_torch_version = True
293
+ else:
294
+ self.enable_torch_version = False
295
+ self.last_attn_probs = None
296
+
297
+ def reset_parameters(self):
298
+ if self.qkv_same_dim:
299
+ nn.init.xavier_uniform_(self.in_proj_weight)
300
+ else:
301
+ nn.init.xavier_uniform_(self.k_proj_weight)
302
+ nn.init.xavier_uniform_(self.v_proj_weight)
303
+ nn.init.xavier_uniform_(self.q_proj_weight)
304
+
305
+ nn.init.xavier_uniform_(self.out_proj.weight)
306
+ if self.in_proj_bias is not None:
307
+ nn.init.constant_(self.in_proj_bias, 0.)
308
+ nn.init.constant_(self.out_proj.bias, 0.)
309
+ if self.bias_k is not None:
310
+ nn.init.xavier_normal_(self.bias_k)
311
+ if self.bias_v is not None:
312
+ nn.init.xavier_normal_(self.bias_v)
313
+
314
+ def forward(
315
+ self,
316
+ query,
317
+ key,
318
+ value,
319
+ key_padding_mask=None,
320
+ incremental_state=None,
321
+ need_weights=True,
322
+ static_kv=False,
323
+ attn_mask=None,
324
+ before_softmax=False,
325
+ need_head_weights=False,
326
+ enc_dec_attn_constraint_mask=None,
327
+ reset_attn_weight=None
328
+ ):
329
+ """Input shape: Time x Batch x Channel
330
+
331
+ Args:
332
+ key_padding_mask (ByteTensor, optional): mask to exclude
333
+ keys that are pads, of shape `(batch, src_len)`, where
334
+ padding elements are indicated by 1s.
335
+ need_weights (bool, optional): return the attention weights,
336
+ averaged over heads (default: False).
337
+ attn_mask (ByteTensor, optional): typically used to
338
+ implement causal attention, where the mask prevents the
339
+ attention from looking forward in time (default: None).
340
+ before_softmax (bool, optional): return the raw attention
341
+ weights and values before the attention softmax.
342
+ need_head_weights (bool, optional): return the attention
343
+ weights for each head. Implies *need_weights*. Default:
344
+ return the average attention weights over all heads.
345
+ """
346
+ if need_head_weights:
347
+ need_weights = True
348
+
349
+ tgt_len, bsz, embed_dim = query.size()
350
+ assert embed_dim == self.embed_dim
351
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
352
+
353
+ if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
354
+ if self.qkv_same_dim:
355
+ return F.multi_head_attention_forward(
356
+ query, key, value, self.embed_dim, self.num_heads,
357
+ self.in_proj_weight, self.in_proj_bias, self.bias_k,
358
+ self.bias_v, self.add_zero_attn, self.dropout,
359
+ self.out_proj.weight, self.out_proj.bias, self.training,
360
+ key_padding_mask, need_weights, attn_mask
361
+ )
362
+ else:
363
+ return F.multi_head_attention_forward(
364
+ query,
365
+ key,
366
+ value,
367
+ self.embed_dim,
368
+ self.num_heads,
369
+ torch.empty([0]),
370
+ self.in_proj_bias,
371
+ self.bias_k,
372
+ self.bias_v,
373
+ self.add_zero_attn,
374
+ self.dropout,
375
+ self.out_proj.weight,
376
+ self.out_proj.bias,
377
+ self.training,
378
+ key_padding_mask,
379
+ need_weights,
380
+ attn_mask,
381
+ use_separate_proj_weight=True,
382
+ q_proj_weight=self.q_proj_weight,
383
+ k_proj_weight=self.k_proj_weight,
384
+ v_proj_weight=self.v_proj_weight
385
+ )
386
+
387
+ if incremental_state is not None:
388
+ raise NotImplementedError("incremental_state decoding is not supported in this port")
390
+ else:
391
+ saved_state = None
392
+
393
+ if self.self_attention:
394
+ # self-attention
395
+ q, k, v = self.in_proj_qkv(query)
396
+ elif self.encoder_decoder_attention:
397
+ # encoder-decoder attention
398
+ q = self.in_proj_q(query)
399
+ if key is None:
400
+ assert value is None
401
+ k = v = None
402
+ else:
403
+ k = self.in_proj_k(key)
404
+ v = self.in_proj_v(key)
405
+
406
+ else:
407
+ q = self.in_proj_q(query)
408
+ k = self.in_proj_k(key)
409
+ v = self.in_proj_v(value)
410
+ q *= self.scaling
411
+
412
+ if self.bias_k is not None:
413
+ assert self.bias_v is not None
414
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
415
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
416
+ if attn_mask is not None:
417
+ attn_mask = torch.cat(
418
+ [attn_mask,
419
+ attn_mask.new_zeros(attn_mask.size(0), 1)],
420
+ dim=1
421
+ )
422
+ if key_padding_mask is not None:
423
+ key_padding_mask = torch.cat(
424
+ [
425
+ key_padding_mask,
426
+ key_padding_mask.new_zeros(
427
+ key_padding_mask.size(0), 1
428
+ )
429
+ ],
430
+ dim=1
431
+ )
432
+
433
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads,
434
+ self.head_dim).transpose(0, 1)
435
+ if k is not None:
436
+ k = k.contiguous().view(-1, bsz * self.num_heads,
437
+ self.head_dim).transpose(0, 1)
438
+ if v is not None:
439
+ v = v.contiguous().view(-1, bsz * self.num_heads,
440
+ self.head_dim).transpose(0, 1)
441
+
442
+ if saved_state is not None:
443
+ raise NotImplementedError("cached saved_state is not supported in this port")
445
+
446
+ src_len = k.size(1)
447
+
448
+ # This is part of a workaround to get around fork/join parallelism
449
+ # not supporting Optional types.
450
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size(
451
+ []
452
+ ):
453
+ key_padding_mask = None
454
+
455
+ if key_padding_mask is not None:
456
+ assert key_padding_mask.size(0) == bsz
457
+ assert key_padding_mask.size(1) == src_len
458
+
459
+ if self.add_zero_attn:
460
+ src_len += 1
461
+ k = torch.cat(
462
+ [k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1
463
+ )
464
+ v = torch.cat(
465
+ [v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1
466
+ )
467
+ if attn_mask is not None:
468
+ attn_mask = torch.cat(
469
+ [attn_mask,
470
+ attn_mask.new_zeros(attn_mask.size(0), 1)],
471
+ dim=1
472
+ )
473
+ if key_padding_mask is not None:
474
+ key_padding_mask = torch.cat(
475
+ [
476
+ key_padding_mask,
477
+ torch.zeros(key_padding_mask.size(0),
478
+ 1).type_as(key_padding_mask)
479
+ ],
480
+ dim=1
481
+ )
482
+
483
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
484
+ attn_weights = self.apply_sparse_mask(
485
+ attn_weights, tgt_len, src_len, bsz
486
+ )
487
+
488
+ assert list(attn_weights.size()) == [
489
+ bsz * self.num_heads, tgt_len, src_len
490
+ ]
491
+
492
+ if attn_mask is not None:
493
+ if len(attn_mask.shape) == 2:
494
+ attn_mask = attn_mask.unsqueeze(0)
495
+ elif len(attn_mask.shape) == 3:
496
+ attn_mask = attn_mask[:, None].repeat(
497
+ [1, self.num_heads, 1, 1]
498
+ ).reshape(bsz * self.num_heads, tgt_len, src_len)
499
+ attn_weights = attn_weights + attn_mask
500
+
501
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
502
+ attn_weights = attn_weights.view(
503
+ bsz, self.num_heads, tgt_len, src_len
504
+ )
505
+ attn_weights = attn_weights.masked_fill(
506
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
507
+ -1e9,
508
+ )
509
+ attn_weights = attn_weights.view(
510
+ bsz * self.num_heads, tgt_len, src_len
511
+ )
512
+
513
+ if key_padding_mask is not None:
514
+ # don't attend to padding symbols
515
+ attn_weights = attn_weights.view(
516
+ bsz, self.num_heads, tgt_len, src_len
517
+ )
518
+ attn_weights = attn_weights.masked_fill(
519
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
520
+ -1e9,
521
+ )
522
+ attn_weights = attn_weights.view(
523
+ bsz * self.num_heads, tgt_len, src_len
524
+ )
525
+
526
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
527
+
528
+ if before_softmax:
529
+ return attn_weights, v
530
+
531
+ attn_weights_float = softmax(attn_weights, dim=-1)
532
+ attn_weights = attn_weights_float.type_as(attn_weights)
533
+ attn_probs = F.dropout(
534
+ attn_weights_float.type_as(attn_weights),
535
+ p=self.dropout,
536
+ training=self.training
537
+ )
538
+
539
+ if reset_attn_weight is not None:
540
+ if reset_attn_weight:
541
+ self.last_attn_probs = attn_probs.detach()
542
+ else:
543
+ assert self.last_attn_probs is not None
544
+ attn_probs = self.last_attn_probs
545
+ attn = torch.bmm(attn_probs, v)
546
+ assert list(attn.size()) == [
547
+ bsz * self.num_heads, tgt_len, self.head_dim
548
+ ]
549
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
550
+ attn = self.out_proj(attn)
551
+
552
+ if need_weights:
553
+ attn_weights = attn_weights_float.view(
554
+ bsz, self.num_heads, tgt_len, src_len
555
+ ).transpose(1, 0)
556
+ if not need_head_weights:
557
+ # average attention weights over heads
558
+ attn_weights = attn_weights.mean(dim=0)
559
+ else:
560
+ attn_weights = None
561
+
562
+ return attn, (attn_weights, attn_logits)
563
+
564
+ def in_proj_qkv(self, query):
565
+ return self._in_proj(query).chunk(3, dim=-1)
566
+
567
+ def in_proj_q(self, query):
568
+ if self.qkv_same_dim:
569
+ return self._in_proj(query, end=self.embed_dim)
570
+ else:
571
+ bias = self.in_proj_bias
572
+ if bias is not None:
573
+ bias = bias[:self.embed_dim]
574
+ return F.linear(query, self.q_proj_weight, bias)
575
+
576
+ def in_proj_k(self, key):
577
+ if self.qkv_same_dim:
578
+ return self._in_proj(
579
+ key, start=self.embed_dim, end=2 * self.embed_dim
580
+ )
581
+ else:
582
+ weight = self.k_proj_weight
583
+ bias = self.in_proj_bias
584
+ if bias is not None:
585
+ bias = bias[self.embed_dim:2 * self.embed_dim]
586
+ return F.linear(key, weight, bias)
587
+
588
+ def in_proj_v(self, value):
589
+ if self.qkv_same_dim:
590
+ return self._in_proj(value, start=2 * self.embed_dim)
591
+ else:
592
+ weight = self.v_proj_weight
593
+ bias = self.in_proj_bias
594
+ if bias is not None:
595
+ bias = bias[2 * self.embed_dim:]
596
+ return F.linear(value, weight, bias)
597
+
598
+ def _in_proj(self, input, start=0, end=None):
599
+ weight = self.in_proj_weight
600
+ bias = self.in_proj_bias
601
+ weight = weight[start:end, :]
602
+ if bias is not None:
603
+ bias = bias[start:end]
604
+ return F.linear(input, weight, bias)
605
+
606
+ def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
607
+ return attn_weights
608
+
609
+
610
+ class TransformerFFNLayer(nn.Module):
611
+ def __init__(
612
+ self,
613
+ hidden_size,
614
+ filter_size,
615
+ padding="SAME",
616
+ kernel_size=1,
617
+ dropout=0.,
618
+ act='gelu'
619
+ ):
620
+ super().__init__()
621
+ self.kernel_size = kernel_size
622
+ self.dropout = dropout
623
+ self.act = act
624
+ if padding == 'SAME':
625
+ self.ffn_1 = nn.Conv1d(
626
+ hidden_size,
627
+ filter_size,
628
+ kernel_size,
629
+ padding=kernel_size // 2
630
+ )
631
+ elif padding == 'LEFT':
632
+ self.ffn_1 = nn.Sequential(
633
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
634
+ nn.Conv1d(hidden_size, filter_size, kernel_size)
635
+ )
636
+ self.ffn_2 = nn.Linear(filter_size, hidden_size)
637
+
638
+ def forward(
639
+ self,
640
+ x,
641
+ ):
642
+ # x: T x B x C
643
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
644
+ x = x * self.kernel_size**-0.5
645
+
646
+ if self.act == 'gelu':
647
+ x = F.gelu(x)
648
+ if self.act == 'relu':
649
+ x = F.relu(x)
650
+ if self.act == 'swish':
651
+ x = F.silu(x)
652
+ x = F.dropout(x, self.dropout, training=self.training)
653
+ x = self.ffn_2(x)
654
+ return x
655
+
656
+
657
+ class EncoderSelfAttentionLayer(nn.Module):
658
+ def __init__(
659
+ self,
660
+ c,
661
+ num_heads,
662
+ dropout,
663
+ attention_dropout=0.1,
664
+ relu_dropout=0.1,
665
+ kernel_size=9,
666
+ padding='SAME',
667
+ norm='ln',
668
+ act='gelu',
669
+ padding_set_zero=True
670
+ ):
671
+ super().__init__()
672
+ self.c = c
673
+ self.dropout = dropout
674
+ self.num_heads = num_heads
675
+ self.padding_set_zero = padding_set_zero
676
+ if num_heads > 0:
677
+ if norm == 'ln':
678
+ self.layer_norm1 = LayerNorm(c)
679
+ elif norm == 'bn':
680
+ self.layer_norm1 = BatchNorm1dTBC(c)
681
+ self.self_attn = MultiheadAttention(
682
+ self.c,
683
+ num_heads=num_heads,
684
+ self_attention=True,
685
+ dropout=attention_dropout,
686
+ bias=False,
687
+ )
688
+ if norm == 'ln':
689
+ self.layer_norm2 = LayerNorm(c)
690
+ elif norm == 'bn':
691
+ self.layer_norm2 = BatchNorm1dTBC(c)
692
+ self.ffn = TransformerFFNLayer(
693
+ c,
694
+ 4 * c,
695
+ kernel_size=kernel_size,
696
+ dropout=relu_dropout,
697
+ padding=padding,
698
+ act=act
699
+ )
700
+
701
+ def forward(self, x, encoder_padding_mask=None, **kwargs):
702
+ layer_norm_training = kwargs.get('layer_norm_training', None)
703
+ if layer_norm_training is not None:
704
+ self.layer_norm1.training = layer_norm_training
705
+ self.layer_norm2.training = layer_norm_training
706
+ if self.num_heads > 0:
707
+ residual = x
708
+ x = self.layer_norm1(x)
709
+ x, _, = self.self_attn(
710
+ query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
711
+ )
712
+ x = F.dropout(x, self.dropout, training=self.training)
713
+ x = residual + x
714
+ if self.padding_set_zero:
715
+ x = x * (1 - encoder_padding_mask.float()).transpose(0,
716
+ 1)[...,
717
+ None]
718
+
719
+ residual = x
720
+ x = self.layer_norm2(x)
721
+ x = self.ffn(x)
722
+ x = F.dropout(x, self.dropout, training=self.training)
723
+ x = residual + x
724
+ if self.padding_set_zero:
725
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[...,
726
+ None]
727
+ return x
728
+
729
+
730
+ class TransformerEncoderLayer(nn.Module):
731
+ def __init__(
732
+ self,
733
+ hidden_size,
734
+ dropout,
735
+ kernel_size,
736
+ num_heads=2,
737
+ norm='ln',
738
+ padding_set_zero=True,
739
+ ):
740
+ super().__init__()
741
+ self.hidden_size = hidden_size
742
+ self.dropout = dropout
743
+ self.num_heads = num_heads
744
+ self.op = EncoderSelfAttentionLayer(
745
+ hidden_size,
746
+ num_heads,
747
+ dropout=dropout,
748
+ attention_dropout=0.0,
749
+ relu_dropout=dropout,
750
+ kernel_size=kernel_size,
751
+ padding="SAME",
752
+ norm=norm,
753
+ act="gelu",
754
+ padding_set_zero=padding_set_zero
755
+ )
756
+
757
+ def forward(self, x, **kwargs):
758
+ return self.op(x, **kwargs)
759
+
760
+
761
+ class FFTBlocks(nn.Module):
762
+ def __init__(
763
+ self,
764
+ hidden_size,
765
+ num_layers,
766
+ ffn_kernel_size=9,
767
+ dropout=0.1,
768
+ num_heads=2,
769
+ use_last_norm=True,
770
+ padding_set_zero=True,
771
+ ):
772
+ super().__init__()
773
+ self.num_layers = num_layers
774
+ embed_dim = self.hidden_size = hidden_size
775
+ self.dropout = dropout
776
+ self.use_last_norm = use_last_norm
777
+ self.padding_set_zero = padding_set_zero
778
+
779
+ self.layers = nn.ModuleList([])
780
+ self.layers.extend(
781
+ [
782
+ TransformerEncoderLayer(
783
+ self.hidden_size,
784
+ self.dropout,
785
+ kernel_size=ffn_kernel_size,
786
+ num_heads=num_heads,
787
+ padding_set_zero=padding_set_zero,
788
+ ) for _ in range(self.num_layers)
789
+ ]
790
+ )
791
+ if self.use_last_norm:
792
+ self.layer_norm = nn.LayerNorm(embed_dim)
793
+ else:
794
+ self.layer_norm = None
795
+
796
+ def forward(self, x, padding_mask=None, attn_mask=None):
797
+ """
798
+ :param x: [B, T, C]
799
+ :param padding_mask: [B, T]
800
+ :return: [B, T, C] or [L, B, T, C]
801
+ """
802
+ if padding_mask is None:
803
+ padding_mask = torch.zeros(x.size(0), x.size(1)).to(x.device)
804
+ nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float(
805
+ )[:, :, None] # [T, B, 1]
806
+ # B x T x C -> T x B x C
807
+ x = x.transpose(0, 1)
808
+ if self.padding_set_zero:
809
+ x = x * nonpadding_mask_TB
810
+ for layer in self.layers:
811
+ x = layer(
812
+ x, encoder_padding_mask=padding_mask, attn_mask=attn_mask
813
+ )
814
+ if self.padding_set_zero:
815
+ x = x * nonpadding_mask_TB
816
+ if self.use_last_norm:
817
+ x = self.layer_norm(x)
818
+ if self.padding_set_zero:
819
+ x = x * nonpadding_mask_TB
820
+
821
+ x = x.transpose(0, 1) # [B, T, C]
822
+ return x
823
+
824
+
825
+ class FastSpeech2EncoderBase(nn.Module):
826
+ def __init__(
827
+ self,
828
+ d_model: int,
829
+ num_layers: int,
830
+ num_heads: int,
831
+ ffn_kernel_size: int,
832
+ d_out: int,
833
+ dropout: float = 0.1,
834
+ rel_pos: bool = True,
835
+ padding_set_zero: bool = True
836
+ ):
837
+ super().__init__()
838
+ self.rel_pos = rel_pos
839
+
840
+ if self.rel_pos:
841
+ self.pos_encoding = RelPositionalEncoding(
842
+ d_model, dropout_rate=0.0
843
+ )
844
+ else:
845
+ self.pos_encoding = SinusoidalPositionalEmbedding(
846
+ d_model, padding_idx=0
847
+ )
848
+ self.dropout = dropout
849
+ self.embed_scale = math.sqrt(d_model)
850
+
851
+ self.layers = FFTBlocks(
852
+ hidden_size=d_model,
853
+ num_layers=num_layers,
854
+ ffn_kernel_size=ffn_kernel_size,
855
+ dropout=dropout,
856
+ num_heads=num_heads,
857
+ use_last_norm=True,
858
+ padding_set_zero=padding_set_zero
859
+ )
860
+
861
+ self.out_proj = nn.Linear(d_model, d_out)
862
+ self.apply(self.init_weights)
863
+
864
+ def init_weights(self, m):
865
+ if isinstance(m, nn.Linear):
866
+ nn.init.xavier_uniform_(m.weight)
867
+ if m.bias is not None:
868
+ nn.init.constant_(m.bias, 0.)
869
+ elif isinstance(m, nn.Embedding):
870
+ nn.init.normal_(m.weight, mean=0, std=m.embedding_dim**-0.5)
871
+
872
+
873
+ @dataclass
874
+ class SpkConfig:
875
+ encoding_format: str
876
+ num_spk: int | None = None
877
+ spk_embed_dim: int | None = None
878
+
879
+ def __post_init__(self):
880
+ allowed_formats = {"id", "embedding"}
881
+ assert self.encoding_format in allowed_formats, f"mode must be one of {allowed_formats}, got '{self.encoding_format}'"
882
+ if self.encoding_format == "id":
883
+ assert self.num_spk is not None
884
+ if self.encoding_format == "embedding":
885
+ assert self.spk_embed_dim is not None
886
+
887
+
888
+ class FastSpeech2PhonemeEncoder(FastSpeech2EncoderBase):
889
+ def __init__(
890
+ self,
891
+ phone_vocab_size,
892
+ d_model,
893
+ num_layers,
894
+ num_heads,
895
+ ffn_kernel_size,
896
+ d_out,
897
+ dropout=0.1,
898
+ rel_pos=False,
899
+ spk_config: SpkConfig | None = None,
900
+ padding_set_zero: bool = True
901
+ ):
902
+ super().__init__(
903
+ d_model=d_model,
904
+ num_layers=num_layers,
905
+ num_heads=num_heads,
906
+ ffn_kernel_size=ffn_kernel_size,
907
+ d_out=d_out,
908
+ dropout=dropout,
909
+ rel_pos=rel_pos,
910
+ padding_set_zero=padding_set_zero
911
+ )
912
+ self.phone_embed = Embedding(phone_vocab_size, d_model)
913
+ self.spk_config = spk_config
914
+ if spk_config is not None:
915
+ if spk_config.encoding_format == "id":
916
+ self.spk_embed_proj = Embedding(
917
+ spk_config.num_spk + 1, d_model
918
+ )
919
+ elif spk_config.encoding_format == "embedding":
920
+ self.spk_embed_proj = Linear(spk_config.spk_embed_dim, d_model)
921
+
922
+ def forward(
923
+ self, phoneme: torch.Tensor, lengths: Sequence[int], spk: torch.Tensor
924
+ ):
925
+ x = self.embed_scale * self.phone_embed(phoneme)
926
+ x = self.pos_encoding(x, lengths)
927
+ x = F.dropout(x, p=self.dropout, training=self.training)
928
+
929
+ padding_mask = ~create_mask_from_length(lengths).to(phoneme.device)
930
+ x = self.layers(x, padding_mask=padding_mask)
931
+
932
+ if self.spk_config is not None:
933
+ spk_embed = self.spk_embed_proj(spk).unsqueeze(1)
934
+ x = x + spk_embed
935
+
936
+ x = self.out_proj(x)
937
+
938
+ return {"output": x, "mask": ~padding_mask}
939
+
940
+
941
+ class FastSpeech2MIDIEncoder(FastSpeech2PhonemeEncoder):
942
+ def __init__(
943
+ self,
944
+ phone_vocab_size: int,
945
+ midi_vocab_size: int,
946
+ slur_vocab_size: int,
947
+ spk_config: SpkConfig | None,
948
+ d_model: int,
949
+ num_layers: int,
950
+ num_heads: int,
951
+ ffn_kernel_size: int,
952
+ d_out: int,
953
+ dropout: float = 0.1,
954
+ rel_pos: bool = True,
955
+ padding_set_zero: bool = True
956
+ ):
957
+ super().__init__(
958
+ phone_vocab_size=phone_vocab_size,
959
+ d_model=d_model,
960
+ num_layers=num_layers,
961
+ num_heads=num_heads,
962
+ ffn_kernel_size=ffn_kernel_size,
963
+ d_out=d_out,
964
+ dropout=dropout,
965
+ rel_pos=rel_pos,
966
+ spk_config=spk_config,
967
+ padding_set_zero=padding_set_zero
968
+ )
969
+ self.midi_embed = Embedding(midi_vocab_size, d_model, padding_idx=0)
970
+ self.midi_dur_embed = Linear(1, d_model)
971
+ self.is_slur_embed = Embedding(slur_vocab_size, d_model)
972
+
973
+ def forward(
974
+ self,
975
+ phoneme: torch.Tensor,
976
+ midi: torch.Tensor,
977
+ midi_duration: torch.Tensor,
978
+ is_slur: torch.Tensor,
979
+ lengths: Sequence[int],
980
+ spk: torch.Tensor | None = None,
981
+ ):
982
+ x = self.embed_scale * self.phone_embed(phoneme)
983
+ midi_embedding = self.midi_embed(midi)
984
+ midi_dur_embedding = self.midi_dur_embed(midi_duration[:, :, None])
985
+ slur_embedding = self.is_slur_embed(is_slur)
986
+
987
+ x = x + midi_embedding + midi_dur_embedding + slur_embedding
988
+ x = self.pos_encoding(x, lengths)
989
+ x = F.dropout(x, p=self.dropout, training=self.training)
990
+
991
+ padding_mask = ~create_mask_from_length(lengths).to(phoneme.device)
992
+ x = self.layers(x, padding_mask=padding_mask)
993
+
994
+ if self.spk_config is not None:
995
+ spk_embed = self.spk_embed_proj(spk).unsqueeze(1)
996
+ x = x + spk_embed
997
+
998
+ x = self.out_proj(x)
999
+
1000
+ return {"output": x, "mask": ~padding_mask}
1001
+
1002
+
1003
+ class FastSpeech2PitchEncoder(FastSpeech2EncoderBase):
1004
+ def __init__(
1005
+ self,
1006
+ phone_vocab_size,
1007
+ d_model,
1008
+ num_layers,
1009
+ num_heads,
1010
+ ffn_kernel_size,
1011
+ d_out,
1012
+ dropout=0.1,
1013
+ rel_pos=False,
1014
+ padding_set_zero=True
1015
+ ):
1016
+ super().__init__(
1017
+ d_model=d_model,
1018
+ num_layers=num_layers,
1019
+ num_heads=num_heads,
1020
+ ffn_kernel_size=ffn_kernel_size,
1021
+ d_out=d_out,
1022
+ dropout=dropout,
1023
+ rel_pos=rel_pos,
1024
+ padding_set_zero=padding_set_zero
1025
+ )
1026
+ self.phone_embed = Embedding(phone_vocab_size, d_model)
1027
+ self.pitch_embed = Embedding(300, d_model)
1028
+
1029
+ def forward(self, phoneme: torch.Tensor, lengths: Sequence[int]):
1030
+ x = self.embed_scale * self.phone_embed(phoneme)
1031
+ x = self.pos_encoding(x, lengths)
1032
+ x = F.dropout(x, p=self.dropout, training=self.training)
1033
+
1034
+ padding_mask = ~create_mask_from_length(lengths).to(phoneme.device)
1035
+ x = self.layers(x, padding_mask=padding_mask)
1036
+
1037
+ x = self.out_proj(x)
1038
+
1039
+ return {"output": x, "mask": ~padding_mask}
1040
+
1041
+ def encode_pitch(self, f0, uv):
1042
+
1043
+ f0_denorm = denorm_f0(f0, uv)
1044
+ pitch = f0_to_coarse(f0_denorm)
1045
+ pitch_embed = self.pitch_embed(pitch)
1046
+ return {"output": pitch_embed}
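Reviewer note (not part of the commit): a smoke-test sketch for `FastSpeech2MIDIEncoder` as added above. Every size below (vocab sizes, `d_model`, lengths, number of speakers) is a made-up assumption for illustration, and it assumes `utils.torch_utilities.create_mask_from_length` accepts a 1-D tensor of lengths, as the module header implies.

import torch
from models.content_encoder.midi_encoder import FastSpeech2MIDIEncoder, SpkConfig

enc = FastSpeech2MIDIEncoder(
    phone_vocab_size=64, midi_vocab_size=128, slur_vocab_size=2,
    spk_config=SpkConfig(encoding_format="id", num_spk=4),
    d_model=256, num_layers=2, num_heads=2, ffn_kernel_size=9, d_out=256,
)
B, T = 2, 16
out = enc(
    phoneme=torch.randint(1, 64, (B, T)),      # phoneme ids (0 reserved for padding)
    midi=torch.randint(1, 128, (B, T)),        # MIDI pitch ids (0 reserved for padding)
    midi_duration=torch.rand(B, T),            # note durations
    is_slur=torch.randint(0, 2, (B, T)),       # slur flags
    lengths=torch.tensor([16, 12]),
    spk=torch.randint(0, 4, (B,)),             # speaker ids, since encoding_format="id"
)
print(out["output"].shape, out["mask"].shape)  # expected: [2, 16, 256] and [2, 16]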
models/content_encoder/sketch_encoder.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ try:
6
+ import torch_npu
7
+ from torch_npu.contrib import transfer_to_npu
8
+ DEVICE_TYPE = "npu"
9
+ except ModuleNotFoundError:
10
+ DEVICE_TYPE = "cuda"
11
+
12
+ from .text_encoder import T5TextEncoder
13
+
14
+ class SketchT5TextEncoder(T5TextEncoder):
15
+ def __init__(
16
+ self, f0_dim: int, energy_dim: int, latent_dim: int,
17
+ embed_dim: int, model_name: str = "google/flan-t5-large",
18
+ ):
19
+ super().__init__(
20
+ embed_dim=embed_dim,
21
+ model_name=model_name,
22
+ )
23
+ self.f0_proj = nn.Linear(f0_dim, latent_dim)
24
+ self.f0_norm = nn.LayerNorm(f0_dim)
25
+ self.energy_proj = nn.Linear(energy_dim, latent_dim)
26
+
27
+ def encode(
28
+ self,
29
+ text: list[str],
30
+ ):
31
+ with torch.no_grad(), torch.amp.autocast(
32
+ device_type=DEVICE_TYPE, enabled=False
33
+ ):
34
+ return super().encode(text)
35
+
36
+ def encode_sketch(
37
+ self,
38
+ f0,
39
+ energy,
40
+ ):
41
+ f0_embed = self.f0_proj(self.f0_norm(f0)).unsqueeze(-1)
42
+ energy_embed = self.energy_proj(energy).unsqueeze(-1)
43
+ sketch_embed = torch.cat([f0_embed, energy_embed], dim=-1)
44
+ return {"output": sketch_embed}
45
+
46
+
47
+ if __name__ == "__main__":
48
+ text_encoder = T5TextEncoder(embed_dim=512)
49
+ text = ["a man is speaking", "a woman is singing while a dog is barking"]
50
+
51
+ output = text_encoder(text)
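Reviewer note (not part of the commit): a usage sketch for `SketchT5TextEncoder.encode_sketch`. The feature dimensions (`f0_dim`/`energy_dim` of 1, 100 frames) are assumptions for illustration, and constructing the class downloads the `google/flan-t5-large` checkpoint for the inherited text branch.

import torch
from models.content_encoder.sketch_encoder import SketchT5TextEncoder

enc = SketchT5TextEncoder(f0_dim=1, energy_dim=1, latent_dim=64, embed_dim=512)
f0 = torch.randn(2, 100, 1)     # [B, T, f0_dim]
energy = torch.randn(2, 100, 1) # [B, T, energy_dim]
sketch = enc.encode_sketch(f0, energy)
print(sketch["output"].shape)   # [2, 100, 64, 2]: projected f0 and energy stacked on the last axis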
models/content_encoder/star_encoder/__pycache__/Qformer.cpython-310.pyc ADDED
Binary file (30.3 kB).
models/content_encoder/star_encoder/__pycache__/star_encoder.cpython-310.pyc ADDED
Binary file (4.08 kB).
models/content_encoder/star_encoder/star_encoder.py ADDED
@@ -0,0 +1,108 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import os
5
+ import sys
6
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
7
+ from Qformer import BertConfig, BertLMHeadModel
8
+
9
+
10
+ try:
11
+ import torch_npu
12
+ from torch_npu.contrib import transfer_to_npu
13
+ DEVICE_TYPE = "npu"
14
+ except ModuleNotFoundError:
15
+ DEVICE_TYPE = "cuda"
16
+
17
+ def generate_length_mask(lens, max_length=None):
18
+ lens = torch.as_tensor(lens)
19
+ N = lens.size(0)
20
+ if max_length is None:
21
+ max_length = max(lens)
22
+ idxs = torch.arange(max_length).repeat(N).view(N, max_length)
23
+ idxs = idxs.to(lens.device)
24
+ mask = (idxs < lens.view(-1, 1)).int()
25
+ return mask
26
+
27
+ class QformerBridgeNet(torch.nn.Module):
28
+ def __init__(self, Qformer_model_name: str = "bert-base-uncased", num_query_token: int = 32,
29
+ hiddin_size: int = 1024, speech_width: int = 1024, freeze_QFormer: bool = True,
30
+ load_from_pretrained: str = None):
31
+ super().__init__()
32
+
33
+ self.Qformer_model_name = Qformer_model_name
34
+ self.audio_Qformer, self.audio_query_tokens, encoder_config = self.init_Qformer(num_query_token=num_query_token, speech_width=speech_width)
35
+ self.audio_Qformer.cls = None
36
+ self.audio_Qformer.bert.embeddings.word_embeddings = None
37
+ self.audio_Qformer.bert.embeddings.position_embeddings = None
38
+ for layer in self.audio_Qformer.bert.encoder.layer:
39
+ layer.output = None
40
+ layer.intermediate = None
41
+
42
+ self.freeze_QFormer = freeze_QFormer
43
+ if freeze_QFormer:
44
+ for name, param in self.audio_Qformer.named_parameters():
45
+ param.requires_grad = False
46
+ self.audio_Qformer.eval()
47
+ self.audio_query_tokens.requires_grad = False
48
+
49
+ self.hiddin_projection = torch.nn.Linear(encoder_config.hidden_size, hiddin_size)
50
+ #torch.nn.init.xavier_uniform_(self.hiddin_projection.weight, gain=torch.nn.init.calculate_gain("relu"))
51
+
52
+ if load_from_pretrained:
53
+ state_dict = torch.load(load_from_pretrained)
54
+ del_key = ["projection.weight", "projection.bias"]
55
+ del_state_dict = {k:v for k, v in state_dict.items() if k not in del_key}
56
+ self.load_state_dict(del_state_dict)
57
+ print("Load adaptor_model_pt from", load_from_pretrained)
58
+
59
+
60
+ def init_Qformer(self, num_query_token, speech_width, num_hidden_layers=2, cross_attention_freq=2):
61
+ encoder_config = BertConfig.from_pretrained(self.Qformer_model_name)
62
+ encoder_config.num_hidden_layers = num_hidden_layers
63
+ encoder_config.encoder_width = speech_width
64
+ # insert cross-attention layer every other block
65
+ encoder_config.add_cross_attention = True
66
+ encoder_config.cross_attention_freq = cross_attention_freq
67
+ encoder_config.query_length = num_query_token
68
+ Qformer = BertLMHeadModel(config=encoder_config)
69
+ query_tokens = nn.Parameter(
70
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
71
+ )
72
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
73
+ return Qformer, query_tokens, encoder_config
74
+
75
+ def hidden(self, batch,):
76
+ audio_feature, lens = batch['embed'], batch['embed_len']
77
+ frame_atts = generate_length_mask(lens).to(audio_feature.device)
78
+ audio_query_tokens=self.audio_query_tokens.expand(audio_feature.shape[0], -1, -1)
79
+ #frame_atts = torch.ones(audio_feature.size()[:-1], dtype=torch.long).to(audio_feature.device)
80
+
81
+ #print(audio_query_tokens.shape, audio_feature.shape, frame_atts.shape)
82
+ audio_query_output=self.audio_Qformer.bert(
83
+ query_embeds=audio_query_tokens, #[32,768]
84
+ encoder_hidden_states=audio_feature,
85
+ encoder_attention_mask=frame_atts,
86
+ return_dict=True,
87
+ )
88
+ audio_hidden = audio_query_output.last_hidden_state
89
+ return audio_hidden
90
+
91
+ def forward(self, batch) -> torch.Tensor:
92
+ with torch.no_grad(), torch.amp.autocast(
93
+ device_type=DEVICE_TYPE, enabled=False
94
+ ):
95
+ x = self.hidden(batch)
96
+ x = self.hiddin_projection(x)
97
+
98
+ mask = torch.ones(x.shape[:2])
99
+ mask = (mask == 1).to(x.device)
100
+ return {"output": x, "mask": mask}
101
+
102
+
103
+ if __name__ == '__main__':
104
+     # The original smoke test referenced T5TextEncoder, which is not defined in this module;
+     # exercise QformerBridgeNet with a dummy feature batch instead.
+     bridge = QformerBridgeNet(num_query_token=32, hiddin_size=1024, speech_width=1024)
105
+     dummy_batch = {"embed": torch.randn(2, 50, 1024), "embed_len": torch.tensor([50, 40])}
106
+     with torch.no_grad():
107
+         output = bridge(dummy_batch)
108
+     print(output["output"].shape, output["mask"].shape)
models/content_encoder/text_encoder.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel
4
+ from transformers.modeling_outputs import BaseModelOutput
5
+
6
+ try:
7
+ import torch_npu
8
+ from torch_npu.contrib import transfer_to_npu
9
+ DEVICE_TYPE = "npu"
10
+ except ModuleNotFoundError:
11
+ DEVICE_TYPE = "cuda"
12
+
13
+
14
+ class TransformersTextEncoderBase(nn.Module):
15
+ def __init__(self, model_name: str, embed_dim: int):
16
+ super().__init__()
17
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ self.model = AutoModel.from_pretrained(model_name)
19
+ self.proj = nn.Linear(self.model.config.hidden_size, embed_dim)
20
+
21
+ def forward(
22
+ self,
23
+ text: list[str],
24
+ ):
25
+ output, mask = self.encode(text)
26
+ output = self.projection(output)
27
+ return {"output": output, "mask": mask}
28
+
29
+ def encode(self, text: list[str]):
30
+ device = self.model.device
31
+ batch = self.tokenizer(
32
+ text,
33
+ max_length=self.tokenizer.model_max_length,
34
+ padding=True,
35
+ truncation=True,
36
+ return_tensors="pt",
37
+ )
38
+ input_ids = batch.input_ids.to(device)
39
+ attention_mask = batch.attention_mask.to(device)
40
+ output: BaseModelOutput = self.model(
41
+ input_ids=input_ids, attention_mask=attention_mask
42
+ )
43
+ output = output.last_hidden_state
44
+ mask = (attention_mask == 1).to(device)
45
+ return output, mask
46
+
47
+ def projection(self, x):
48
+ return self.proj(x)
49
+
50
+
51
+ class T5TextEncoder(TransformersTextEncoderBase):
52
+ def __init__(
53
+ self, embed_dim: int, model_name: str = "google/flan-t5-large"
54
+ ):
55
+ nn.Module.__init__(self)
56
+ self.tokenizer = T5Tokenizer.from_pretrained(model_name)
57
+ self.model = T5EncoderModel.from_pretrained(model_name)
58
+ for param in self.model.parameters():
59
+ param.requires_grad = False
60
+ self.model.eval()
61
+ self.proj = nn.Linear(self.model.config.hidden_size, embed_dim)
62
+
63
+ def encode(
64
+ self,
65
+ text: list[str],
66
+ ):
67
+ with torch.no_grad(), torch.amp.autocast(
68
+ device_type=DEVICE_TYPE, enabled=False
69
+ ):
70
+ return super().encode(text)
71
+
72
+
73
+ if __name__ == "__main__":
74
+ text_encoder = T5TextEncoder(embed_dim=512)
75
+ text = ["a man is speaking", "a woman is singing while a dog is barking"]
76
+
77
+ output = text_encoder(text)
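Reviewer note (not part of the commit): a short check of the `T5TextEncoder` contract beyond the `__main__` block above, illustrating that only the projection layer is trainable. Shapes depend on the tokenizer; the values in the comments are illustrative.

from models.content_encoder.text_encoder import T5TextEncoder

enc = T5TextEncoder(embed_dim=512)   # downloads google/flan-t5-large on first use
out = enc(["a man is speaking", "a woman is singing while a dog is barking"])
print(out["output"].shape)           # [2, num_tokens, 512] after the projection layer
print(out["mask"].shape)             # [2, num_tokens] boolean attention mask
print(any(p.requires_grad for p in enc.model.parameters()))  # False: the T5 backbone is frozen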
models/content_encoder/vision_encoder.py ADDED
@@ -0,0 +1,34 @@
1
+ from typing import Sequence
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from utils.torch_utilities import create_mask_from_length
8
+
9
+
10
+ class MlpVideoEncoder(nn.Module):
11
+ def __init__(
12
+ self,
13
+ video_feat_dim: int,
14
+ embed_dim: int,
15
+ ):
16
+ super().__init__()
17
+ self.mlp = nn.Linear(video_feat_dim, embed_dim)
18
+ self.init_weights()
19
+
20
+ def init_weights(self):
21
+ def _init_weights(module):
22
+ if isinstance(module, nn.Linear):
23
+ nn.init.xavier_uniform_(module.weight)
24
+ if module.bias is not None:
25
+ nn.init.constant_(module.bias, 0.)
26
+
27
+ self.apply(_init_weights)
28
+
29
+ def forward(self, frames: torch.Tensor, frame_nums: Sequence[int]):
30
+ device = frames.device
31
+ x = F.normalize(frames, p=2, dim=-1)
32
+ x = self.mlp(x)
33
+ mask = create_mask_from_length(frame_nums).to(device)
34
+ return {"output": x, "mask": mask}
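Reviewer note (not part of the commit): a minimal sketch of `MlpVideoEncoder` on dummy frame features. The feature dimension and frame counts are assumptions, and it relies on `create_mask_from_length` accepting a plain list of lengths.

import torch
from models.content_encoder.vision_encoder import MlpVideoEncoder

enc = MlpVideoEncoder(video_feat_dim=768, embed_dim=512)
frames = torch.randn(2, 32, 768)               # [B, num_frames, video_feat_dim]
out = enc(frames, frame_nums=[32, 20])
print(out["output"].shape, out["mask"].shape)  # [2, 32, 512] and [2, 32]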
models/diffsinger_net.py ADDED
@@ -0,0 +1,119 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class Mish(nn.Module):
8
+ def forward(self, x):
9
+ return x * torch.tanh(F.softplus(x))
10
+
11
+
12
+ class SinusoidalPosEmb(nn.Module):
13
+ def __init__(self, dim):
14
+ super(SinusoidalPosEmb, self).__init__()
15
+ self.dim = dim
16
+
17
+ def forward(self, x):
18
+ device = x.device
19
+ half_dim = self.dim // 2
20
+ emb = math.log(10000) / (half_dim-1)
21
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
22
+ emb = x.unsqueeze(1) * emb.unsqueeze(0)
23
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
24
+ return emb
25
+
26
+
27
+ class ResidualBlock(nn.Module):
28
+ def __init__(self, encoder_hidden, residual_channels, dilation):
29
+ super().__init__()
30
+ self.dilated_conv = nn.Conv1d(
31
+ residual_channels,
32
+ 2 * residual_channels,
33
+ 3,
34
+ padding=dilation,
35
+ dilation=dilation
36
+ )
37
+ self.diffusion_projection = nn.Linear(
38
+ residual_channels, residual_channels
39
+ )
40
+ self.conditioner_projection = nn.Conv1d(
41
+ encoder_hidden, 2 * residual_channels, 1
42
+ )
43
+ self.output_projection = nn.Conv1d(
44
+ residual_channels, 2 * residual_channels, 1
45
+ )
46
+
47
+ def forward(self, x, conditioner, diffusion_step):
48
+ diffusion_step = self.diffusion_projection(diffusion_step
49
+ ).unsqueeze(-1)
50
+ conditioner = self.conditioner_projection(conditioner)
51
+ y = x + diffusion_step
52
+
53
+ y = self.dilated_conv(y) + conditioner
54
+
55
+ gate, filter = torch.chunk(y, 2, dim=1)
56
+ y = torch.sigmoid(gate) * torch.tanh(filter)
57
+
58
+ y = self.output_projection(y)
59
+ residual, skip = torch.chunk(y, 2, dim=1)
60
+ return (x+residual) / math.sqrt(2.0), skip
61
+
62
+
63
+ class DiffSingerNet(nn.Module):
64
+ def __init__(
65
+ self,
66
+ in_dims=128,
67
+ residual_channels=256,
68
+ encoder_hidden=256,
69
+ dilation_cycle_length=4,
70
+ residual_layers=20,
71
+ ):
72
+ super().__init__()
73
+
74
+ # self.pe_scale = pe_scale
75
+
76
+ self.input_projection = nn.Conv1d(in_dims, residual_channels, 1)
77
+ self.time_pos_emb = SinusoidalPosEmb(residual_channels)
78
+ dim = residual_channels
79
+ self.mlp = nn.Sequential(
80
+ nn.Linear(dim, dim * 4), Mish(), nn.Linear(dim * 4, dim)
81
+ )
82
+ self.residual_layers = nn.ModuleList([
83
+ ResidualBlock(
84
+ encoder_hidden, residual_channels,
85
+ 2**(i % dilation_cycle_length)
86
+ ) for i in range(residual_layers)
87
+ ])
88
+ self.skip_projection = nn.Conv1d(
89
+ residual_channels, residual_channels, 1
90
+ )
91
+ self.output_projection = nn.Conv1d(residual_channels, in_dims, 1)
92
+ nn.init.zeros_(self.output_projection.weight)
93
+
94
+ def forward(self, x, timesteps, context, x_mask=None, context_mask=None):
95
+ # make it compatible with int time step during inference
96
+ if timesteps.dim() == 0:
97
+ timesteps = timesteps.expand(x.shape[0]
98
+ ).to(x.device, dtype=torch.long)
99
+
100
+ x = self.input_projection(x) # x [B, residual_channel, T]
101
+
102
+ x = F.relu(x)
103
+
104
+ t = self.time_pos_emb(timesteps)
105
+ t = self.mlp(t)
106
+
107
+ cond = context
108
+
109
+ skip = []
110
+ for layer_id, layer in enumerate(self.residual_layers):
111
+ x, skip_connection = layer(x, cond, t)
112
+ skip.append(skip_connection)
113
+
114
+ x = torch.sum(torch.stack(skip),
115
+ dim=0) / math.sqrt(len(self.residual_layers))
116
+ x = self.skip_projection(x)
117
+ x = F.relu(x)
118
+ x = self.output_projection(x) # [B, M, T]
119
+ if x_mask is None:
+     # default to an all-ones mask so the x_mask=None default in the signature actually works
+     x_mask = torch.ones(x.shape[0], x.shape[2], device=x.device, dtype=x.dtype)
+ return x * x_mask.unsqueeze(1)
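Reviewer note (not part of the commit): a shape-level smoke test for `DiffSingerNet`. The sizes are assumptions; `context` must already be at `encoder_hidden` width and the same temporal resolution as `x`, since the residual blocks add the projected conditioner frame by frame.

import torch
from models.diffsinger_net import DiffSingerNet

net = DiffSingerNet(in_dims=128, residual_channels=64, encoder_hidden=64, residual_layers=4)
B, T = 2, 50
x = torch.randn(B, 128, T)               # noisy mel/latent, [B, in_dims, T]
context = torch.randn(B, 64, T)          # conditioning at encoder_hidden width
timesteps = torch.randint(0, 1000, (B,))
x_mask = torch.ones(B, T)
out = net(x, timesteps, context, x_mask=x_mask)
print(out.shape)                         # [B, 128, T]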
models/diffusion.py ADDED
@@ -0,0 +1,1261 @@
1
+ from typing import Sequence
2
+ import random
3
+ from typing import Any
4
+ from pathlib import Path
5
+
6
+ from tqdm import tqdm
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import diffusers.schedulers as noise_schedulers
11
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
12
+ from diffusers.utils.torch_utils import randn_tensor
13
+
14
+ from models.autoencoder.autoencoder_base import AutoEncoderBase
15
+ from models.content_encoder.content_encoder import ContentEncoder
16
+ from models.content_adapter import ContentAdapterBase
17
+ from models.common import LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase
18
+ from utils.torch_utilities import (
19
+ create_alignment_path, create_mask_from_length, loss_with_mask,
20
+ trim_or_pad_length
21
+ )
22
+ from safetensors.torch import load_file
23
+
24
+ class DiffusionMixin:
25
+ def __init__(
26
+ self,
27
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
28
+ snr_gamma: float = None,
29
+ cfg_drop_ratio: float = 0.2
30
+ ) -> None:
31
+ self.noise_scheduler_name = noise_scheduler_name
32
+ self.snr_gamma = snr_gamma
33
+ self.classifier_free_guidance = cfg_drop_ratio > 0.0
34
+ self.cfg_drop_ratio = cfg_drop_ratio
35
+ self.noise_scheduler = noise_schedulers.DDPMScheduler.from_pretrained(
36
+ self.noise_scheduler_name, subfolder="scheduler"
37
+ )
38
+
39
+ def compute_snr(self, timesteps) -> torch.Tensor:
40
+ """
41
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
42
+ """
43
+ alphas_cumprod = self.noise_scheduler.alphas_cumprod
44
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
45
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod)**0.5
46
+
47
+ # Expand the tensors.
48
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
49
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device
50
+ )[timesteps].float()
51
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
52
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
53
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
54
+
55
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
56
+ device=timesteps.device
57
+ )[timesteps].float()
58
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
59
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[...,
60
+ None]
61
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
62
+
63
+ # Compute SNR.
64
+ snr = (alpha / sigma)**2
65
+ return snr
66
+
67
+ def get_timesteps(
68
+ self,
69
+ batch_size: int,
70
+ device: torch.device,
71
+ training: bool = True
72
+ ) -> torch.Tensor:
73
+ if training:
74
+ timesteps = torch.randint(
75
+ 0,
76
+ self.noise_scheduler.config.num_train_timesteps,
77
+ (batch_size, ),
78
+ device=device
79
+ )
80
+ else:
81
+ # validation on half of the total timesteps
82
+ timesteps = (self.noise_scheduler.config.num_train_timesteps //
83
+ 2) * torch.ones((batch_size, ),
84
+ dtype=torch.int64,
85
+ device=device)
86
+
87
+ timesteps = timesteps.long()
88
+ return timesteps
89
+
90
+ def get_target(
91
+ self, latent: torch.Tensor, noise: torch.Tensor,
92
+ timesteps: torch.Tensor
93
+ ) -> torch.Tensor:
94
+ """
95
+ Get the target for loss depending on the prediction type
96
+ """
97
+ if self.noise_scheduler.config.prediction_type == "epsilon":
98
+ target = noise
99
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
100
+ target = self.noise_scheduler.get_velocity(
101
+ latent, noise, timesteps
102
+ )
103
+ else:
104
+ raise ValueError(
105
+ f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
106
+ )
107
+ return target
108
+
109
+ def loss_with_snr(
110
+ self, pred: torch.Tensor, target: torch.Tensor,
111
+ timesteps: torch.Tensor, mask: torch.Tensor,
112
+ loss_reduce: bool = True,
113
+ ) -> torch.Tensor:
114
+ if self.snr_gamma is None:
115
+ loss = F.mse_loss(pred.float(), target.float(), reduction="none")
116
+ loss = loss_with_mask(loss, mask, reduce=loss_reduce)
117
+ else:
118
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
119
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L1006
120
+ snr = self.compute_snr(timesteps)
121
+ mse_loss_weights = torch.stack(
122
+ [
123
+ snr,
124
+ self.snr_gamma * torch.ones_like(timesteps),
125
+ ],
126
+ dim=1,
127
+ ).min(dim=1)[0]
128
+ # division by (snr + 1) does not work well, not clear about the reason
129
+ mse_loss_weights = mse_loss_weights / snr
130
+ loss = F.mse_loss(pred.float(), target.float(), reduction="none")
131
+ loss = loss_with_mask(loss, mask, reduce=False) * mse_loss_weights
132
+ if loss_reduce:
133
+ loss = loss.mean()
134
+ return loss
135
+
136
+ def rescale_cfg(
137
+ self, pred_cond: torch.Tensor, pred_cfg: torch.Tensor,
138
+ guidance_rescale: float
139
+ ):
140
+ """
141
+ Rescale `pred_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
142
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
143
+ """
144
+ std_cond = pred_cond.std(
145
+ dim=list(range(1, pred_cond.ndim)), keepdim=True
146
+ )
147
+ std_cfg = pred_cfg.std(dim=list(range(1, pred_cfg.ndim)), keepdim=True)
148
+
149
+ pred_rescaled = pred_cfg * (std_cond / std_cfg)
150
+ pred_cfg = guidance_rescale * pred_rescaled + (
151
+ 1 - guidance_rescale
152
+ ) * pred_cfg
153
+ return pred_cfg
154
+
155
+ class CrossAttentionAudioDiffusion(
156
+ LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
157
+ DiffusionMixin
158
+ ):
159
+ def __init__(
160
+ self,
161
+ autoencoder: AutoEncoderBase,
162
+ content_encoder: ContentEncoder,
163
+ content_adapter: ContentAdapterBase,
164
+ backbone: nn.Module,
165
+ duration_offset: float = 1.0,
166
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
167
+ snr_gamma: float = None,
168
+ cfg_drop_ratio: float = 0.2,
169
+ ):
170
+ nn.Module.__init__(self)
171
+ DiffusionMixin.__init__(
172
+ self, noise_scheduler_name, snr_gamma, cfg_drop_ratio
173
+ )
174
+
175
+ self.autoencoder = autoencoder
176
+ for param in self.autoencoder.parameters():
177
+ param.requires_grad = False
178
+
179
+ self.content_encoder = content_encoder
180
+ self.content_encoder.audio_encoder.model = self.autoencoder
181
+ self.content_adapter = content_adapter
182
+ self.backbone = backbone
183
+ self.duration_offset = duration_offset
184
+ self.dummy_param = nn.Parameter(torch.empty(0))
185
+
186
+ def forward(
187
+ self, content: list[Any], task: list[str], waveform: torch.Tensor,
188
+ waveform_lengths: torch.Tensor, instruction: torch.Tensor,
189
+ instruction_lengths: Sequence[int], **kwargs
190
+ ):
191
+ device = self.dummy_param.device
192
+ num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
193
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
194
+
195
+ self.autoencoder.eval()
196
+ with torch.no_grad():
197
+ latent, latent_mask = self.autoencoder.encode(
198
+ waveform.unsqueeze(1), waveform_lengths
199
+ )
200
+
201
+ content_output: dict[
202
+ str, torch.Tensor] = self.content_encoder.encode_content(
203
+ content, task, device=device
204
+ )
205
+ content, content_mask = content_output["content"], content_output[
206
+ "content_mask"]
207
+ instruction_mask = create_mask_from_length(instruction_lengths)
208
+ content, content_mask, global_duration_pred, _ = \
209
+ self.content_adapter(content, content_mask, instruction, instruction_mask)
210
+ global_duration_target = torch.log(
211
+ latent_mask.sum(1) / self.autoencoder.latent_token_rate +
212
+ self.duration_offset
213
+ )
214
+ global_duration_loss = F.mse_loss(
215
+ global_duration_target, global_duration_pred
216
+ )
217
+
218
+ if self.training and self.classifier_free_guidance:
219
+ mask_indices = [
220
+ k for k in range(len(waveform))
221
+ if random.random() < self.cfg_drop_ratio
222
+ ]
223
+ if len(mask_indices) > 0:
224
+ content[mask_indices] = 0
225
+
226
+ batch_size = latent.shape[0]
227
+ timesteps = self.get_timesteps(batch_size, device, self.training)
228
+ noise = torch.randn_like(latent)
229
+ noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
230
+ target = self.get_target(latent, noise, timesteps)
231
+
232
+ pred: torch.Tensor = self.backbone(
233
+ x=noisy_latent,
234
+ timesteps=timesteps,
235
+ context=content,
236
+ x_mask=latent_mask,
237
+ context_mask=content_mask
238
+ )
239
+
240
+ pred = pred.transpose(1, self.autoencoder.time_dim)
241
+ target = target.transpose(1, self.autoencoder.time_dim)
242
+ diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask)
243
+
244
+ return {
245
+ "diff_loss": diff_loss,
246
+ "global_duration_loss": global_duration_loss,
247
+ }
248
+
249
+ @torch.no_grad()
250
+ def inference(
251
+ self,
252
+ content: list[Any],
253
+ condition: list[Any],
254
+ task: list[str],
255
+ instruction: torch.Tensor,
256
+ instruction_lengths: Sequence[int],
257
+ scheduler: SchedulerMixin,
258
+ num_steps: int = 20,
259
+ guidance_scale: float = 3.0,
260
+ guidance_rescale: float = 0.0,
261
+ disable_progress: bool = True,
262
+ **kwargs
263
+ ):
264
+ device = self.dummy_param.device
265
+ classifier_free_guidance = guidance_scale > 1.0
266
+
267
+ content_output: dict[
268
+ str, torch.Tensor] = self.content_encoder.encode_content(
269
+ content, task, device=device
270
+ )
271
+ content, content_mask = content_output["content"], content_output[
272
+ "content_mask"]
273
+
274
+ instruction_mask = create_mask_from_length(instruction_lengths)
275
+ content, content_mask, global_duration_pred, _ = \
276
+ self.content_adapter(content, content_mask, instruction, instruction_mask)
277
+ batch_size = content.size(0)
278
+
279
+ if classifier_free_guidance:
280
+ uncond_content = torch.zeros_like(content)
281
+ uncond_content_mask = content_mask.detach().clone()
282
+ content = torch.cat([uncond_content, content])
283
+ content_mask = torch.cat([uncond_content_mask, content_mask])
284
+
285
+ scheduler.set_timesteps(num_steps, device=device)
286
+ timesteps = scheduler.timesteps
287
+
288
+ global_duration_pred = torch.exp(
289
+ global_duration_pred
290
+ ) - self.duration_offset
291
+ global_duration_pred *= self.autoencoder.latent_token_rate
292
+ global_duration_pred = torch.round(global_duration_pred)
293
+
294
+ latent_shape = tuple(
295
+ int(global_duration_pred.max().item()) if dim is None else dim
296
+ for dim in self.autoencoder.latent_shape
297
+ )
298
+ latent = self.prepare_latent(
299
+ batch_size, scheduler, latent_shape, content.dtype, device
300
+ )
301
+ latent_mask = create_mask_from_length(global_duration_pred).to(
302
+ content_mask.device
303
+ )
304
+ if classifier_free_guidance:
305
+ latent_mask = torch.cat([latent_mask, latent_mask])
306
+
307
+ num_warmup_steps = len(timesteps) - num_steps * scheduler.order
308
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
309
+
310
+ for i, timestep in enumerate(timesteps):
311
+ # expand the latent if we are doing classifier free guidance
312
+ latent_input = torch.cat([latent, latent]
313
+ ) if classifier_free_guidance else latent
314
+ latent_input = scheduler.scale_model_input(latent_input, timestep)
315
+
316
+ noise_pred = self.backbone(
317
+ x=latent_input,
318
+ x_mask=latent_mask,
319
+ timesteps=timestep,
320
+ context=content,
321
+ context_mask=content_mask,
322
+ )
323
+
324
+ # perform guidance
325
+ if classifier_free_guidance:
326
+ noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
327
+ noise_pred = noise_pred_uncond + guidance_scale * (
328
+ noise_pred_content - noise_pred_uncond
329
+ )
330
+ if guidance_rescale != 0.0:
331
+ noise_pred = self.rescale_cfg(
332
+ noise_pred_content, noise_pred, guidance_rescale
333
+ )
334
+
335
+ # compute the previous noisy sample x_t -> x_t-1
336
+ latent = scheduler.step(noise_pred, timestep, latent).prev_sample
337
+
338
+ # call the callback, if provided
339
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
340
+ (i + 1) % scheduler.order == 0):
341
+ progress_bar.update(1)
342
+
343
+ waveform = self.autoencoder.decode(latent)
344
+
345
+ return waveform
346
+
347
+ def prepare_latent(
348
+ self, batch_size: int, scheduler: SchedulerMixin,
349
+ latent_shape: Sequence[int], dtype: torch.dtype, device: str
350
+ ):
351
+ shape = (batch_size, *latent_shape)
352
+ latent = randn_tensor(
353
+ shape, generator=None, device=device, dtype=dtype
354
+ )
355
+ # scale the initial noise by the standard deviation required by the scheduler
356
+ latent = latent * scheduler.init_noise_sigma
357
+ return latent
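
A hypothetical usage sketch for `inference`, assuming an already-constructed `CrossAttentionAudioDiffusion` instance `model` with loaded weights, a pre-computed instruction embedding `instruction_emb`, and a task name accepted by the content encoder (these names are assumptions, not taken from this file):

```python
import torch
from diffusers import DDIMScheduler

# `model`, `instruction_emb` and the task string below are placeholders.
scheduler = DDIMScheduler.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="scheduler"
)
with torch.no_grad():
    waveform = model.inference(
        content=["a dog barking in the rain"],
        condition=[None],
        task=["text_to_audio"],                      # assumed task identifier
        instruction=instruction_emb,                 # (1, L, E) instruction embedding
        instruction_lengths=[instruction_emb.size(1)],
        scheduler=scheduler,
        num_steps=50,
        guidance_scale=3.0,
        guidance_rescale=0.7,
    )
```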
358
+
359
+ class SingleTaskCrossAttentionAudioDiffusion(CrossAttentionAudioDiffusion
360
+ ):
361
+ def __init__(
362
+ self,
363
+ autoencoder: AutoEncoderBase,
364
+ content_encoder: ContentEncoder,
365
+ backbone: nn.Module,
366
+ pretrained_ckpt: str | Path = None,
367
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
368
+ snr_gamma: float = None,
369
+ cfg_drop_ratio: float = 0.2,
370
+ ):
371
+ nn.Module.__init__(self)
372
+ DiffusionMixin.__init__(
373
+ self, noise_scheduler_name, snr_gamma, cfg_drop_ratio
374
+ )
375
+
376
+ self.autoencoder = autoencoder
377
+ for param in self.autoencoder.parameters():
378
+ param.requires_grad = False
379
+
380
+ self.backbone = backbone
381
+ if pretrained_ckpt is not None:
382
+ pretrained_state_dict = load_file(pretrained_ckpt)
383
+ self.load_pretrained(pretrained_state_dict)
384
+
385
+ self.content_encoder = content_encoder
386
+ #self.content_encoder.audio_encoder.model = self.autoencoder
387
+ self.dummy_param = nn.Parameter(torch.empty(0))
388
+
389
+ def forward(
390
+ self, content: list[Any], condition: list[Any], task: list[str], waveform: torch.Tensor,
391
+ waveform_lengths: torch.Tensor, loss_reduce: bool = True, **kwargs
392
+ ):
393
+ loss_reduce = self.training or loss_reduce  # always reduce during training; otherwise honor the caller's flag
394
+ device = self.dummy_param.device
395
+ num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
396
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
397
+
398
+ self.autoencoder.eval()
399
+ with torch.no_grad():
400
+ latent, latent_mask = self.autoencoder.encode(
401
+ waveform.unsqueeze(1), waveform_lengths
402
+ )
403
+
404
+ content_output: dict[
405
+ str, torch.Tensor] = self.content_encoder.encode_content(
406
+ content, task, device=device
407
+ )
408
+ content, content_mask = content_output["content"], content_output[
409
+ "content_mask"]
410
+
411
+ if self.training and self.classifier_free_guidance:
412
+ mask_indices = [
413
+ k for k in range(len(waveform))
414
+ if random.random() < self.cfg_drop_ratio
415
+ ]
416
+ if len(mask_indices) > 0:
417
+ content[mask_indices] = 0
418
+
419
+ batch_size = latent.shape[0]
420
+ timesteps = self.get_timesteps(batch_size, device, self.training)
421
+ noise = torch.randn_like(latent)
422
+ noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
423
+ target = self.get_target(latent, noise, timesteps)
424
+
425
+ pred: torch.Tensor = self.backbone(
426
+ x=noisy_latent,
427
+ timesteps=timesteps,
428
+ context=content,
429
+ x_mask=latent_mask,
430
+ context_mask=content_mask
431
+ )
432
+
433
+ pred = pred.transpose(1, self.autoencoder.time_dim)
434
+ target = target.transpose(1, self.autoencoder.time_dim)
435
+ diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask, loss_reduce=loss_reduce)
436
+
437
+ return {
438
+ "diff_loss": diff_loss,
439
+ }
440
+
441
+ @torch.no_grad()
442
+ def inference(
443
+ self,
444
+ content: list[Any],
445
+ condition: list[Any],
446
+ task: list[str],
447
+ scheduler: SchedulerMixin,
448
+ latent_shape: Sequence[int],
449
+ num_steps: int = 20,
450
+ guidance_scale: float = 3.0,
451
+ guidance_rescale: float = 0.0,
452
+ disable_progress: bool = True,
453
+ **kwargs
454
+ ):
455
+ device = self.dummy_param.device
456
+ classifier_free_guidance = guidance_scale > 1.0
457
+
458
+ content_output: dict[
459
+ str, torch.Tensor] = self.content_encoder.encode_content(
460
+ content, task, device=device
461
+ )
462
+ content, content_mask = content_output["content"], content_output[
463
+ "content_mask"]
464
+ batch_size = content.size(0)
465
+
466
+ if classifier_free_guidance:
467
+ uncond_content = torch.zeros_like(content)
468
+ uncond_content_mask = content_mask.detach().clone()
469
+ content = torch.cat([uncond_content, content])
470
+ content_mask = torch.cat([uncond_content_mask, content_mask])
471
+
472
+ scheduler.set_timesteps(num_steps, device=device)
473
+ timesteps = scheduler.timesteps
474
+
475
+ latent = self.prepare_latent(
476
+ batch_size, scheduler, latent_shape, content.dtype, device
477
+ )
478
+
479
+ num_warmup_steps = len(timesteps) - num_steps * scheduler.order
480
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
481
+
482
+ for i, timestep in enumerate(timesteps):
483
+ # expand the latent if we are doing classifier free guidance
484
+ latent_input = torch.cat([latent, latent]
485
+ ) if classifier_free_guidance else latent
486
+ latent_input = scheduler.scale_model_input(latent_input, timestep)
487
+
488
+ noise_pred = self.backbone(
489
+ x=latent_input,
490
+ timesteps=timestep,
491
+ context=content,
492
+ context_mask=content_mask,
493
+ )
494
+
495
+ # perform guidance
496
+ if classifier_free_guidance:
497
+ noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
498
+ noise_pred = noise_pred_uncond + guidance_scale * (
499
+ noise_pred_content - noise_pred_uncond
500
+ )
501
+ if guidance_rescale != 0.0:
502
+ noise_pred = self.rescale_cfg(
503
+ noise_pred_content, noise_pred, guidance_rescale
504
+ )
505
+
506
+ # compute the previous noisy sample x_t -> x_t-1
507
+ latent = scheduler.step(noise_pred, timestep, latent).prev_sample
508
+
509
+ # call the callback, if provided
510
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
511
+ (i + 1) % scheduler.order == 0):
512
+ progress_bar.update(1)
513
+
514
+ waveform = self.autoencoder.decode(latent)
515
+
516
+ return waveform
517
+
518
+
519
+ class DummyContentAudioDiffusion(CrossAttentionAudioDiffusion):
520
+ def __init__(
521
+ self,
522
+ autoencoder: AutoEncoderBase,
523
+ content_encoder: ContentEncoder,
524
+ content_adapter: ContentAdapterBase,
525
+ backbone: nn.Module,
526
+ content_dim: int,
527
+ frame_resolution: float,
528
+ duration_offset: float = 1.0,
529
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
530
+ snr_gamma: float = None,
531
+ cfg_drop_ratio: float = 0.2,
532
+ ):
533
+ """
534
+ Args:
535
+ autoencoder:
536
+ Pretrained audio autoencoder that encodes raw waveforms into latent
537
+ space and decodes latents back to waveforms.
538
+ content_encoder:
539
+ Module that produces content embeddings (e.g., from text, MIDI, or
540
+ other modalities) used to guide the diffusion.
541
+ content_adapter (ContentAdapterBase):
542
+ Adapter module that fuses task instruction embeddings and content embeddings,
543
+ and performs duration prediction for time-aligned tasks.
544
+ backbone:
545
+ U‑Net or Transformer backbone that performs the core denoising
546
+ operations in latent space.
547
+ content_dim:
548
+ Dimension of the content embeddings produced by the `content_encoder`
549
+ and `content_adapter`.
550
+ frame_resolution:
551
+ Time resolution, in seconds, of each content frame when predicting
552
+ duration alignment. Used when calculating duration loss.
553
+ duration_offset:
554
+ A small positive offset (in frames) added to duration targets before
555
+ taking the log, for numerical stability of log-scale duration prediction.
556
+ noise_scheduler_name:
557
+ Identifier of the pretrained noise scheduler to use.
558
+ snr_gamma:
559
+ Clipping value used in the min-SNR diffusion loss weighting strategy.
560
+ cfg_drop_ratio:
561
+ Probability of dropping the content conditioning during training
562
+ to support CFG.
563
+ """
564
+ super().__init__(
565
+ autoencoder=autoencoder,
566
+ content_encoder=content_encoder,
567
+ content_adapter=content_adapter,
568
+ backbone=backbone,
569
+ duration_offset=duration_offset,
570
+ noise_scheduler_name=noise_scheduler_name,
571
+ snr_gamma=snr_gamma,
572
+ cfg_drop_ratio=cfg_drop_ratio,
573
+ )
574
+ self.frame_resolution = frame_resolution
575
+ self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim))
576
+ self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim))
577
+
578
+ def forward(
579
+ self, content, duration, task, is_time_aligned, waveform,
580
+ waveform_lengths, instruction, instruction_lengths, **kwargs
581
+ ):
582
+ device = self.dummy_param.device
583
+ num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
584
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
585
+
586
+ self.autoencoder.eval()
587
+ with torch.no_grad():
588
+ latent, latent_mask = self.autoencoder.encode(
589
+ waveform.unsqueeze(1), waveform_lengths
590
+ )
591
+
592
+ # content: (B, L, E)
593
+ content_output: dict[
594
+ str, torch.Tensor] = self.content_encoder.encode_content(
595
+ content, task, device=device
596
+ )
597
+ length_aligned_content = content_output["length_aligned_content"]
598
+ content, content_mask = content_output["content"], content_output[
599
+ "content_mask"]
600
+ instruction_mask = create_mask_from_length(instruction_lengths)
601
+
602
+ content, content_mask, global_duration_pred, local_duration_pred = \
603
+ self.content_adapter(content, content_mask, instruction, instruction_mask)
604
+
605
+ n_frames = torch.round(duration / self.frame_resolution)
606
+ local_duration_target = torch.log(n_frames + self.duration_offset)
607
+ global_duration_target = torch.log(
608
+ latent_mask.sum(1) / self.autoencoder.latent_token_rate +
609
+ self.duration_offset
610
+ )
611
+
612
+ # truncate unused non time aligned duration prediction
613
+ if is_time_aligned.sum() > 0:
614
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
615
+ else:
616
+ trunc_ta_length = content.size(1)
617
+
618
+ # local duration loss
619
+ local_duration_pred = local_duration_pred[:, :trunc_ta_length]
620
+ ta_content_mask = content_mask[:, :trunc_ta_length]
621
+ local_duration_target = local_duration_target.to(
622
+ dtype=local_duration_pred.dtype
623
+ )
624
+ local_duration_loss = loss_with_mask(
625
+ (local_duration_target - local_duration_pred)**2,
626
+ ta_content_mask,
627
+ reduce=False
628
+ )
629
+ local_duration_loss *= is_time_aligned
630
+ if is_time_aligned.sum().item() == 0:
631
+ local_duration_loss *= 0.0
632
+ local_duration_loss = local_duration_loss.mean()
633
+ else:
634
+ local_duration_loss = local_duration_loss.sum(
635
+ ) / is_time_aligned.sum()
636
+
637
+ # global duration loss
638
+ global_duration_loss = F.mse_loss(
639
+ global_duration_target, global_duration_pred
640
+ )
641
+
642
+ # --------------------------------------------------------------------
643
+ # prepare latent and diffusion-related noise
644
+ # --------------------------------------------------------------------
645
+
646
+ batch_size = latent.shape[0]
647
+ timesteps = self.get_timesteps(batch_size, device, self.training)
648
+ noise = torch.randn_like(latent)
649
+ noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
650
+ target = self.get_target(latent, noise, timesteps)
651
+
652
+ # --------------------------------------------------------------------
653
+ # duration adapter
654
+ # --------------------------------------------------------------------
655
+ if is_time_aligned.sum() == 0 and \
656
+ duration.size(1) < content_mask.size(1):
657
+ # for non-time-aligned tasks like TTA, `duration` is a dummy placeholder
658
+ duration = F.pad(
659
+ duration, (0, content_mask.size(1) - duration.size(1))
660
+ )
661
+ n_latents = torch.round(duration * self.autoencoder.latent_token_rate)
662
+ # content_mask: [B, L], helper_latent_mask: [B, T]
663
+ helper_latent_mask = create_mask_from_length(n_latents.sum(1)).to(
664
+ content_mask.device
665
+ )
666
+ attn_mask = ta_content_mask.unsqueeze(
667
+ -1
668
+ ) * helper_latent_mask.unsqueeze(1)
669
+ # attn_mask: [B, L, T]
670
+ align_path = create_alignment_path(n_latents, attn_mask)
671
+ time_aligned_content = content[:, :trunc_ta_length]
672
+ time_aligned_content = torch.matmul(
673
+ align_path.transpose(1, 2).to(content.dtype), time_aligned_content
674
+ ) # (B, T, L) x (B, L, E) -> (B, T, E)
675
+
676
+ # --------------------------------------------------------------------
677
+ # prepare input to the backbone
678
+ # --------------------------------------------------------------------
679
+ # TODO: compatibility with 2D spectrogram VAEs
680
+ latent_length = noisy_latent.size(self.autoencoder.time_dim)
681
+ time_aligned_content = trim_or_pad_length(
682
+ time_aligned_content, latent_length, 1
683
+ )
684
+ length_aligned_content = trim_or_pad_length(
685
+ length_aligned_content, latent_length, 1
686
+ )
687
+ # time_aligned_content: from monotonically aligned input (e.g. phonemes), before frame expansion
688
+ # length_aligned_content: from frame-aligned input (e.g. F0 / energy)
689
+ time_aligned_content = time_aligned_content + length_aligned_content
690
+ time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
691
+ time_aligned_content.dtype
692
+ )
693
+
694
+ context = content
695
+ context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
696
+ # only use the first dummy non time aligned embedding
697
+ context_mask = content_mask.detach().clone()
698
+ context_mask[is_time_aligned, 1:] = False
699
+
700
+ # truncate dummy non time aligned context
701
+ if is_time_aligned.sum().item() < batch_size:
702
+ trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
703
+ else:
704
+ trunc_nta_length = content.size(1)
705
+ context = context[:, :trunc_nta_length]
706
+ context_mask = context_mask[:, :trunc_nta_length]
707
+
708
+ # --------------------------------------------------------------------
709
+ # classifier free guidance
710
+ # --------------------------------------------------------------------
711
+ if self.training and self.classifier_free_guidance:
712
+ mask_indices = [
713
+ k for k in range(len(waveform))
714
+ if random.random() < self.cfg_drop_ratio
715
+ ]
716
+ if len(mask_indices) > 0:
717
+ context[mask_indices] = 0
718
+ time_aligned_content[mask_indices] = 0
719
+
720
+ pred: torch.Tensor = self.backbone(
721
+ x=noisy_latent,
722
+ timesteps=timesteps,
723
+ time_aligned_context=time_aligned_content,
724
+ context=context,
725
+ x_mask=latent_mask,
726
+ context_mask=context_mask
727
+ )
728
+ pred = pred.transpose(1, self.autoencoder.time_dim)
729
+ target = target.transpose(1, self.autoencoder.time_dim)
730
+ diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask)
731
+ return {
732
+ "diff_loss": diff_loss,
733
+ "local_duration_loss": local_duration_loss,
734
+ "global_duration_loss": global_duration_loss
735
+ }
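
A minimal sketch of the log-domain duration encoding used for the duration losses above, and how it is inverted at inference time (the 20 ms frame resolution is an assumed value):

```python
import torch

duration_offset = 1.0
frame_resolution = 0.02                        # assumed 20 ms content frames
duration_sec = torch.tensor([0.30, 0.08])      # per-token durations in seconds

# training target: log(frame count + offset)
n_frames = torch.round(duration_sec / frame_resolution)   # tensor([15., 4.])
target = torch.log(n_frames + duration_offset)

# inference: a predicted log-duration is mapped back to a frame count
pred_frames = torch.ceil(torch.exp(target)) - duration_offset   # ~= n_frames
```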
736
+
737
+ @torch.no_grad()
738
+ def inference(
739
+ self,
740
+ content: list[Any],
741
+ condition: list[Any],
742
+ task: list[str],
743
+ is_time_aligned: list[bool],
744
+ instruction: torch.Tensor,
745
+ instruction_lengths: Sequence[int],
746
+ scheduler: SchedulerMixin,
747
+ num_steps: int = 20,
748
+ guidance_scale: float = 3.0,
749
+ guidance_rescale: float = 0.0,
750
+ disable_progress: bool = True,
751
+ use_gt_duration: bool = False,
752
+ **kwargs
753
+ ):
754
+ device = self.dummy_param.device
755
+ classifier_free_guidance = guidance_scale > 1.0
756
+
757
+ content_output: dict[
758
+ str, torch.Tensor] = self.content_encoder.encode_content(
759
+ content, task, device=device
760
+ )
761
+ length_aligned_content = content_output["length_aligned_content"]
762
+ content, content_mask = content_output["content"], content_output[
763
+ "content_mask"]
764
+ instruction_mask = create_mask_from_length(instruction_lengths)
765
+ content, content_mask, global_duration_pred, local_duration_pred = \
766
+ self.content_adapter(content, content_mask, instruction, instruction_mask)
767
+
768
+ scheduler.set_timesteps(num_steps, device=device)
769
+ timesteps = scheduler.timesteps
770
+ batch_size = content.size(0)
771
+
772
+ # truncate dummy time aligned duration prediction
773
+ is_time_aligned = torch.as_tensor(is_time_aligned)
774
+ if is_time_aligned.sum() > 0:
775
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
776
+ else:
777
+ trunc_ta_length = content.size(1)
778
+
779
+ # prepare local duration
780
+ local_duration_pred = torch.exp(local_duration_pred) * content_mask
781
+ local_duration_pred = torch.ceil(
782
+ local_duration_pred
783
+ ) - self.duration_offset  # number of frames at `self.frame_resolution` resolution
784
+ local_duration_pred = torch.round(local_duration_pred * self.frame_resolution * \
785
+ self.autoencoder.latent_token_rate)
786
+ local_duration_pred = local_duration_pred[:, :trunc_ta_length]
787
+ # use ground truth duration
788
+ if use_gt_duration and "duration" in kwargs:
789
+ local_duration_pred = torch.round(
790
+ torch.as_tensor(kwargs["duration"]) *
791
+ self.autoencoder.latent_token_rate
792
+ ).to(device)
793
+
794
+ # prepare global duration
795
+ global_duration = local_duration_pred.sum(1)
796
+ global_duration_pred = torch.exp(
797
+ global_duration_pred
798
+ ) - self.duration_offset
799
+ global_duration_pred *= self.autoencoder.latent_token_rate
800
+ global_duration_pred = torch.round(global_duration_pred)
801
+ global_duration[~is_time_aligned] = global_duration_pred[
802
+ ~is_time_aligned]
803
+
804
+ # --------------------------------------------------------------------
805
+ # duration adapter
806
+ # --------------------------------------------------------------------
807
+ time_aligned_content = content[:, :trunc_ta_length]
808
+ ta_content_mask = content_mask[:, :trunc_ta_length]
809
+ latent_mask = create_mask_from_length(global_duration).to(
810
+ content_mask.device
811
+ )
812
+ attn_mask = ta_content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
813
+ # attn_mask: [B, L, T]
814
+ align_path = create_alignment_path(local_duration_pred, attn_mask)
815
+ time_aligned_content = torch.matmul(
816
+ align_path.transpose(1, 2).to(content.dtype), time_aligned_content
817
+ ) # (B, T, L) x (B, L, E) -> (B, T, E)
818
+ time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
819
+ time_aligned_content.dtype
820
+ )
821
+
822
+ length_aligned_content = trim_or_pad_length(
823
+ length_aligned_content, time_aligned_content.size(1), 1
824
+ )
825
+ time_aligned_content = time_aligned_content + length_aligned_content
826
+
827
+ # --------------------------------------------------------------------
828
+ # prepare unconditional input
829
+ # --------------------------------------------------------------------
830
+ context = content
831
+ context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
832
+ context_mask = content_mask
833
+ context_mask[
834
+ is_time_aligned,
835
+ 1:] = False # only use the first dummy non time aligned embedding
836
+ # truncate dummy non time aligned context
837
+ if is_time_aligned.sum().item() < batch_size:
838
+ trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
839
+ else:
840
+ trunc_nta_length = content.size(1)
841
+ context = context[:, :trunc_nta_length]
842
+ context_mask = context_mask[:, :trunc_nta_length]
843
+
844
+ if classifier_free_guidance:
845
+ uncond_time_aligned_content = torch.zeros_like(
846
+ time_aligned_content
847
+ )
848
+ uncond_context = torch.zeros_like(context)
849
+ uncond_context_mask = context_mask.detach().clone()
850
+ time_aligned_content = torch.cat([
851
+ uncond_time_aligned_content, time_aligned_content
852
+ ])
853
+ context = torch.cat([uncond_context, context])
854
+ context_mask = torch.cat([uncond_context_mask, context_mask])
855
+ latent_mask = torch.cat([
856
+ latent_mask, latent_mask.detach().clone()
857
+ ])
858
+
859
+ # --------------------------------------------------------------------
860
+ # prepare input to the backbone
861
+ # --------------------------------------------------------------------
862
+ latent_shape = tuple(
863
+ int(global_duration.max().item()) if dim is None else dim
864
+ for dim in self.autoencoder.latent_shape
865
+ )
866
+ shape = (batch_size, *latent_shape)
867
+ latent = randn_tensor(
868
+ shape, generator=None, device=device, dtype=content.dtype
869
+ )
870
+ # scale the initial noise by the standard deviation required by the scheduler
871
+ latent = latent * scheduler.init_noise_sigma
872
+
873
+ num_warmup_steps = len(timesteps) - num_steps * scheduler.order
874
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
875
+ # --------------------------------------------------------------------
876
+ # iteratively denoising
877
+ # --------------------------------------------------------------------
878
+ for i, timestep in enumerate(timesteps):
879
+ # expand the latent if we are doing classifier free guidance
880
+ if classifier_free_guidance:
881
+ latent_input = torch.cat([latent, latent])
882
+ else:
883
+ latent_input = latent
884
+
885
+ latent_input = scheduler.scale_model_input(latent_input, timestep)
886
+ noise_pred = self.backbone(
887
+ x=latent_input,
888
+ x_mask=latent_mask,
889
+ timesteps=timestep,
890
+ time_aligned_context=time_aligned_content,
891
+ context=context,
892
+ context_mask=context_mask
893
+ )
894
+
895
+ if classifier_free_guidance:
896
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
897
+ noise_pred = noise_pred_uncond + guidance_scale * (
898
+ noise_pred_cond - noise_pred_uncond
899
+ )
900
+ if guidance_rescale != 0.0:
901
+ noise_pred = self.rescale_cfg(
902
+ noise_pred_cond, noise_pred, guidance_rescale
903
+ )
904
+
905
+ # compute the previous noisy sample x_t -> x_t-1
906
+ latent = scheduler.step(noise_pred, timestep, latent).prev_sample
907
+
908
+ # call the callback, if provided
909
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
910
+ (i + 1) % scheduler.order == 0):
911
+ progress_bar.update(1)
912
+
913
+ progress_bar.close()
914
+
915
+ # TODO variable length decoding, using `latent_mask`
916
+ waveform = self.autoencoder.decode(latent)
917
+ return waveform
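
For clarity, an illustrative sketch of the duration-based expansion performed by the alignment path above: a hard (B, L, T) path repeats each of the L content token embeddings for its predicted number of latent frames via a matrix product (the hand-built path stands in for `create_alignment_path`, whose implementation lives outside this file):

```python
import torch

durations = torch.tensor([[2, 3, 1]])             # frames per token, (B, L)
content = torch.arange(3.0).reshape(1, 3, 1)      # (B, L, E) with E = 1
T = int(durations.sum())

# hand-built hard alignment path of shape (B, L, T)
path = torch.zeros(1, 3, T)
ends = durations.cumsum(dim=1)
starts = ends - durations
for l in range(3):
    path[0, l, int(starts[0, l]):int(ends[0, l])] = 1.0

expanded = path.transpose(1, 2) @ content          # (B, T, E)
# expanded.squeeze() -> tensor([0., 0., 1., 1., 1., 2.])
```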
918
+
919
+
920
+ class DoubleContentAudioDiffusion(CrossAttentionAudioDiffusion):
921
+ def __init__(
922
+ self,
923
+ autoencoder: AutoEncoderBase,
924
+ content_encoder: ContentEncoder,
925
+ content_adapter: nn.Module,
926
+ backbone: nn.Module,
927
+ content_dim: int,
928
+ frame_resolution: float,
929
+ duration_offset: float = 1.0,
930
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
931
+ snr_gamma: float = None,
932
+ cfg_drop_ratio: float = 0.2,
933
+ ):
934
+ super().__init__(
935
+ autoencoder=autoencoder,
936
+ content_encoder=content_encoder,
937
+ content_adapter=content_adapter,
938
+ backbone=backbone,
939
+ duration_offset=duration_offset,
940
+ noise_scheduler_name=noise_scheduler_name,
941
+ snr_gamma=snr_gamma,
942
+ cfg_drop_ratio=cfg_drop_ratio
943
+ )
944
+ self.frame_resolution = frame_resolution
945
+
946
+ def forward(
947
+ self, content, duration, task, is_time_aligned, waveform,
948
+ waveform_lengths, instruction, instruction_lengths, **kwargs
949
+ ):
950
+ device = self.dummy_param.device
951
+ num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
952
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
953
+
954
+ self.autoencoder.eval()
955
+ with torch.no_grad():
956
+ latent, latent_mask = self.autoencoder.encode(
957
+ waveform.unsqueeze(1), waveform_lengths
958
+ )
959
+
960
+ content_output: dict[
961
+ str, torch.Tensor] = self.content_encoder.encode_content(
962
+ content, task, device=device
963
+ )
964
+ length_aligned_content = content_output["length_aligned_content"]
965
+ content, content_mask = content_output["content"], content_output[
966
+ "content_mask"]
967
+ context_mask = content_mask.detach()
968
+ instruction_mask = create_mask_from_length(instruction_lengths)
969
+
970
+ content, content_mask, global_duration_pred, local_duration_pred = \
971
+ self.content_adapter(content, content_mask, instruction, instruction_mask)
972
+
973
+ # TODO: if all samples are non-time-aligned, content length may exceed duration length
974
+
975
+ n_frames = torch.round(duration / self.frame_resolution)
976
+ local_duration_target = torch.log(n_frames + self.duration_offset)
977
+ global_duration_target = torch.log(
978
+ latent_mask.sum(1) / self.autoencoder.latent_token_rate +
979
+ self.duration_offset
980
+ )
981
+ # truncate unused non time aligned duration prediction
982
+ if is_time_aligned.sum() > 0:
983
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
984
+ else:
985
+ trunc_ta_length = content.size(1)
986
+ # local duration loss
987
+ local_duration_pred = local_duration_pred[:, :trunc_ta_length]
988
+ ta_content_mask = content_mask[:, :trunc_ta_length]
989
+ local_duration_target = local_duration_target.to(
990
+ dtype=local_duration_pred.dtype
991
+ )
992
+ local_duration_loss = loss_with_mask(
993
+ (local_duration_target - local_duration_pred)**2,
994
+ ta_content_mask,
995
+ reduce=False
996
+ )
997
+ local_duration_loss *= is_time_aligned
998
+ if is_time_aligned.sum().item() == 0:
999
+ local_duration_loss *= 0.0
1000
+ local_duration_loss = local_duration_loss.mean()
1001
+ else:
1002
+ local_duration_loss = local_duration_loss.sum(
1003
+ ) / is_time_aligned.sum()
1004
+
1005
+ # global duration loss
1006
+ global_duration_loss = F.mse_loss(
1007
+ global_duration_target, global_duration_pred
1008
+ )
1009
+ # --------------------------------------------------------------------
1010
+ # prepare latent and diffusion-related noise
1011
+ # --------------------------------------------------------------------
1012
+ batch_size = latent.shape[0]
1013
+ timesteps = self.get_timesteps(batch_size, device, self.training)
1014
+ noise = torch.randn_like(latent)
1015
+ noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
1016
+ target = self.get_target(latent, noise, timesteps)
1017
+
1018
+ # --------------------------------------------------------------------
1019
+ # duration adapter
1020
+ # --------------------------------------------------------------------
1021
+ # content_mask: [B, L], helper_latent_mask: [B, T]
1022
+ if is_time_aligned.sum() == 0 and \
1023
+ duration.size(1) < content_mask.size(1):
1024
+ # for non-time-aligned tasks like TTA, `duration` is a dummy placeholder
1025
+ duration = F.pad(
1026
+ duration, (0, content_mask.size(1) - duration.size(1))
1027
+ )
1028
+ n_latents = torch.round(duration * self.autoencoder.latent_token_rate)
1029
+ helper_latent_mask = create_mask_from_length(n_latents.sum(1)).to(
1030
+ content_mask.device
1031
+ )
1032
+ attn_mask = ta_content_mask.unsqueeze(
1033
+ -1
1034
+ ) * helper_latent_mask.unsqueeze(1)
1035
+ align_path = create_alignment_path(n_latents, attn_mask)
1036
+ time_aligned_content = content[:, :trunc_ta_length]
1037
+ time_aligned_content = torch.matmul(
1038
+ align_path.transpose(1, 2).to(content.dtype), time_aligned_content
1039
+ )
1040
+
1041
+ latent_length = noisy_latent.size(self.autoencoder.time_dim)
1042
+ time_aligned_content = trim_or_pad_length(
1043
+ time_aligned_content, latent_length, 1
1044
+ )
1045
+ length_aligned_content = trim_or_pad_length(
1046
+ length_aligned_content, latent_length, 1
1047
+ )
1048
+ time_aligned_content = time_aligned_content + length_aligned_content
1049
+ context = content
1050
+ # --------------------------------------------------------------------
1051
+ # classifier free guidance
1052
+ # --------------------------------------------------------------------
1053
+ if self.training and self.classifier_free_guidance:
1054
+ mask_indices = [
1055
+ k for k in range(len(waveform))
1056
+ if random.random() < self.cfg_drop_ratio
1057
+ ]
1058
+ if len(mask_indices) > 0:
1059
+ context[mask_indices] = 0
1060
+ time_aligned_content[mask_indices] = 0
1061
+
1062
+ pred: torch.Tensor = self.backbone(
1063
+ x=noisy_latent,
1064
+ timesteps=timesteps,
1065
+ time_aligned_context=time_aligned_content,
1066
+ context=context,
1067
+ x_mask=latent_mask,
1068
+ context_mask=context_mask,
1069
+ )
1070
+ pred = pred.transpose(1, self.autoencoder.time_dim)
1071
+ target = target.transpose(1, self.autoencoder.time_dim)
1072
+ diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask)
1073
+ return {
1074
+ "diff_loss": diff_loss,
1075
+ "local_duration_loss": local_duration_loss,
1076
+ "global_duration_loss": global_duration_loss,
1077
+ }
1078
+
1079
+ @torch.no_grad()
1080
+ def inference(
1081
+ self,
1082
+ content: list[Any],
1083
+ condition: list[Any],
1084
+ task: list[str],
1085
+ is_time_aligned: list[bool],
1086
+ instruction: torch.Tensor,
1087
+ instruction_lengths: Sequence[int],
1088
+ scheduler: SchedulerMixin,
1089
+ num_steps: int = 20,
1090
+ guidance_scale: float = 3.0,
1091
+ guidance_rescale: float = 0.0,
1092
+ disable_progress: bool = True,
1093
+ use_gt_duration: bool = False,
1094
+ **kwargs
1095
+ ):
1096
+ device = self.dummy_param.device
1097
+ classifier_free_guidance = guidance_scale > 1.0
1098
+
1099
+ content_output: dict[
1100
+ str, torch.Tensor] = self.content_encoder.encode_content(
1101
+ content, task, device=device
1102
+ )
1103
+ length_aligned_content = content_output["length_aligned_content"]
1104
+ content, content_mask = content_output["content"], content_output[
1105
+ "content_mask"]
1106
+ instruction_mask = create_mask_from_length(instruction_lengths)
1107
+
1108
+ content, content_mask, global_duration_pred, local_duration_pred = \
1109
+ self.content_adapter(content, content_mask, instruction, instruction_mask)
1110
+
1111
+ scheduler.set_timesteps(num_steps, device=device)
1112
+ timesteps = scheduler.timesteps
1113
+ batch_size = content.size(0)
1114
+
1115
+ # truncate dummy time aligned duration prediction
1116
+ is_time_aligned = torch.as_tensor(is_time_aligned)
1117
+ if is_time_aligned.sum() > 0:
1118
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
1119
+ else:
1120
+ trunc_ta_length = content.size(1)
1121
+
1122
+ # prepare local duration
1123
+ local_duration_pred = torch.exp(local_duration_pred) * content_mask
1124
+ local_duration_pred = torch.ceil(
1125
+ local_duration_pred
1126
+ ) - self.duration_offset  # number of frames at `self.frame_resolution` resolution
1127
+ local_duration_pred = torch.round(local_duration_pred * self.frame_resolution * \
1128
+ self.autoencoder.latent_token_rate)
1129
+ local_duration_pred = local_duration_pred[:, :trunc_ta_length]
1130
+ # use ground truth duration
1131
+ if use_gt_duration and "duration" in kwargs:
1132
+ local_duration_pred = torch.round(
1133
+ torch.as_tensor(kwargs["duration"]) *
1134
+ self.autoencoder.latent_token_rate
1135
+ ).to(device)
1136
+
1137
+ # prepare global duration
1138
+ global_duration = local_duration_pred.sum(1)
1139
+ global_duration_pred = torch.exp(
1140
+ global_duration_pred
1141
+ ) - self.duration_offset
1142
+ global_duration_pred *= self.autoencoder.latent_token_rate
1143
+ global_duration_pred = torch.round(global_duration_pred)
1144
+ global_duration[~is_time_aligned] = global_duration_pred[
1145
+ ~is_time_aligned]
1146
+
1147
+ # --------------------------------------------------------------------
1148
+ # duration adapter
1149
+ # --------------------------------------------------------------------
1150
+ time_aligned_content = content[:, :trunc_ta_length]
1151
+ ta_content_mask = content_mask[:, :trunc_ta_length]
1152
+ latent_mask = create_mask_from_length(global_duration).to(
1153
+ content_mask.device
1154
+ )
1155
+ attn_mask = ta_content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
1156
+ # attn_mask: [B, L, T]
1157
+ align_path = create_alignment_path(local_duration_pred, attn_mask)
1158
+ time_aligned_content = torch.matmul(
1159
+ align_path.transpose(1, 2).to(content.dtype), time_aligned_content
1160
+ ) # (B, T, L) x (B, L, E) -> (B, T, E)
1161
+
1162
+ # time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
1163
+ # time_aligned_content.dtype
1164
+ # )
1165
+
1166
+ length_aligned_content = trim_or_pad_length(
1167
+ length_aligned_content, time_aligned_content.size(1), 1
1168
+ )
1169
+ time_aligned_content = time_aligned_content + length_aligned_content
1170
+
1171
+ # --------------------------------------------------------------------
1172
+ # prepare unconditional input
1173
+ # --------------------------------------------------------------------
1174
+ context = content
1175
+ # context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
1176
+ context_mask = content_mask
1177
+ # context_mask[
1178
+ # is_time_aligned,
1179
+ # 1:] = False # only use the first dummy non time aligned embedding
1180
+ # # truncate dummy non time aligned context
1181
+ # if is_time_aligned.sum().item() < batch_size:
1182
+ # trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
1183
+ # else:
1184
+ # trunc_nta_length = content.size(1)
1185
+ # context = context[:, :trunc_nta_length]
1186
+ # context_mask = context_mask[:, :trunc_nta_length]
1187
+
1188
+ if classifier_free_guidance:
1189
+ uncond_time_aligned_content = torch.zeros_like(
1190
+ time_aligned_content
1191
+ )
1192
+ uncond_context = torch.zeros_like(context)
1193
+ uncond_context_mask = context_mask.detach().clone()
1194
+ time_aligned_content = torch.cat([
1195
+ uncond_time_aligned_content, time_aligned_content
1196
+ ])
1197
+ context = torch.cat([uncond_context, context])
1198
+ context_mask = torch.cat([uncond_context_mask, context_mask])
1199
+ latent_mask = torch.cat([
1200
+ latent_mask, latent_mask.detach().clone()
1201
+ ])
1202
+
1203
+ # --------------------------------------------------------------------
1204
+ # prepare input to the backbone
1205
+ # --------------------------------------------------------------------
1206
+ latent_shape = tuple(
1207
+ int(global_duration.max().item()) if dim is None else dim
1208
+ for dim in self.autoencoder.latent_shape
1209
+ )
1210
+ shape = (batch_size, *latent_shape)
1211
+ latent = randn_tensor(
1212
+ shape, generator=None, device=device, dtype=content.dtype
1213
+ )
1214
+ # scale the initial noise by the standard deviation required by the scheduler
1215
+ latent = latent * scheduler.init_noise_sigma
1216
+
1217
+ num_warmup_steps = len(timesteps) - num_steps * scheduler.order
1218
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
1219
+ # --------------------------------------------------------------------
1220
+ # iteratively denoising
1221
+ # --------------------------------------------------------------------
1222
+ for i, timestep in enumerate(timesteps):
1223
+ # expand the latent if we are doing classifier free guidance
1224
+ if classifier_free_guidance:
1225
+ latent_input = torch.cat([latent, latent])
1226
+ else:
1227
+ latent_input = latent
1228
+
1229
+ latent_input = scheduler.scale_model_input(latent_input, timestep)
1230
+ noise_pred = self.backbone(
1231
+ x=latent_input,
1232
+ x_mask=latent_mask,
1233
+ timesteps=timestep,
1234
+ time_aligned_context=time_aligned_content,
1235
+ context=context,
1236
+ context_mask=context_mask
1237
+ )
1238
+
1239
+ if classifier_free_guidance:
1240
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
1241
+ noise_pred = noise_pred_uncond + guidance_scale * (
1242
+ noise_pred_cond - noise_pred_uncond
1243
+ )
1244
+ if guidance_rescale != 0.0:
1245
+ noise_pred = self.rescale_cfg(
1246
+ noise_pred_cond, noise_pred, guidance_rescale
1247
+ )
1248
+
1249
+ # compute the previous noisy sample x_t -> x_t-1
1250
+ latent = scheduler.step(noise_pred, timestep, latent).prev_sample
1251
+
1252
+ # call the callback, if provided
1253
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
1254
+ (i + 1) % scheduler.order == 0):
1255
+ progress_bar.update(1)
1256
+
1257
+ progress_bar.close()
1258
+
1259
+ # TODO variable length decoding, using `latent_mask`
1260
+ waveform = self.autoencoder.decode(latent)
1261
+ return waveform
models/dit/__pycache__/attention.cpython-310.pyc ADDED
Binary file (7.69 kB). View file
 
models/dit/__pycache__/audio_dit.cpython-310.pyc ADDED
Binary file (10 kB). View file
 
models/dit/__pycache__/mask_dit.cpython-310.pyc ADDED
Binary file (14.6 kB). View file
 
models/dit/__pycache__/modules.cpython-310.pyc ADDED
Binary file (14 kB). View file
 
models/dit/__pycache__/rotary.cpython-310.pyc ADDED
Binary file (2.77 kB). View file
 
models/dit/__pycache__/span_mask.cpython-310.pyc ADDED
Binary file (4.73 kB). View file
 
models/dit/attention.py ADDED
@@ -0,0 +1,349 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint
5
+ import einops
6
+ from einops import rearrange, repeat
7
+ from inspect import isfunction
8
+ from .rotary import RotaryEmbedding
9
+ from .modules import RMSNorm
10
+
11
+ if hasattr(nn.functional, 'scaled_dot_product_attention'):
12
+ ATTENTION_MODE = 'flash'
13
+ else:
14
+ ATTENTION_MODE = 'math'
15
+ print(f'attention mode is {ATTENTION_MODE}')
16
+
17
+
18
+ def add_mask(sim, mask):
19
+ b, ndim = sim.shape[0], mask.ndim
20
+ if ndim == 3:
21
+ mask = rearrange(mask, "b n m -> b 1 n m")
22
+ if ndim == 2:
23
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
24
+ max_neg_value = -torch.finfo(sim.dtype).max
25
+ sim = sim.masked_fill(~mask, max_neg_value)
26
+ return sim
27
+
28
+
29
+ def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
30
+ def default(val, d):
31
+ return val if val is not None else (d() if isfunction(d) else d)
32
+
33
+ b, i, j, device = q_shape[0], q_shape[-2], k_shape[-2], device
34
+ q_mask = default(
35
+ q_mask, torch.ones((b, i), device=device, dtype=torch.bool)
36
+ )
37
+ k_mask = default(
38
+ k_mask, torch.ones((b, j), device=device, dtype=torch.bool)
39
+ )
40
+ attn_mask = rearrange(q_mask, 'b i -> b 1 i 1'
41
+ ) * rearrange(k_mask, 'b j -> b 1 1 j')
42
+ return attn_mask
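
A small shape check for the helpers above: the broadcasted product of a query mask of shape (b, i) and a key mask of shape (b, j) yields a boolean attention mask of shape (b, 1, i, j).

```python
import torch
from einops import rearrange

b, i, j = 2, 5, 7
q_mask = torch.ones(b, i, dtype=torch.bool)
k_mask = torch.ones(b, j, dtype=torch.bool)
attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
assert attn_mask.shape == (b, 1, i, j)
```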
43
+
44
+
45
+ class Attention(nn.Module):
46
+ def __init__(
47
+ self,
48
+ dim,
49
+ context_dim=None,
50
+ num_heads=8,
51
+ qkv_bias=False,
52
+ qk_scale=None,
53
+ qk_norm=None,
54
+ attn_drop=0.,
55
+ proj_drop=0.,
56
+ rope_mode='none'
57
+ ):
58
+ super().__init__()
59
+ self.num_heads = num_heads
60
+ head_dim = dim // num_heads
61
+ self.scale = qk_scale or head_dim**-0.5
62
+
63
+ if context_dim is None:
64
+ self.cross_attn = False
65
+ else:
66
+ self.cross_attn = True
67
+
68
+ context_dim = dim if context_dim is None else context_dim
69
+
70
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
71
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
72
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
73
+
74
+ if qk_norm is None:
75
+ self.norm_q = nn.Identity()
76
+ self.norm_k = nn.Identity()
77
+ elif qk_norm == 'layernorm':
78
+ self.norm_q = nn.LayerNorm(head_dim)
79
+ self.norm_k = nn.LayerNorm(head_dim)
80
+ elif qk_norm == 'rmsnorm':
81
+ self.norm_q = RMSNorm(head_dim)
82
+ self.norm_k = RMSNorm(head_dim)
83
+ else:
84
+ raise NotImplementedError
85
+
86
+ self.attn_drop_p = attn_drop
87
+ self.attn_drop = nn.Dropout(attn_drop)
88
+ self.proj = nn.Linear(dim, dim)
89
+ self.proj_drop = nn.Dropout(proj_drop)
90
+
91
+ if self.cross_attn:
92
+ assert rope_mode == 'none'
93
+ self.rope_mode = rope_mode
94
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
95
+ self.rotary = RotaryEmbedding(dim=head_dim)
96
+ elif self.rope_mode == 'dual':
97
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
98
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
99
+
100
+ def _rotary(self, q, k, extras):
101
+ if self.rope_mode == 'shared':
102
+ q, k = self.rotary(q=q, k=k)
103
+ elif self.rope_mode == 'x_only':
104
+ q_x, k_x = self.rotary(
105
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
106
+ )
107
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
108
+ q = torch.cat((q_c, q_x), dim=2)
109
+ k = torch.cat((k_c, k_x), dim=2)
110
+ elif self.rope_mode == 'dual':
111
+ q_x, k_x = self.rotary_x(
112
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
113
+ )
114
+ q_c, k_c = self.rotary_c(
115
+ q=q[:, :, :extras, :], k=k[:, :, :extras, :]
116
+ )
117
+ q = torch.cat((q_c, q_x), dim=2)
118
+ k = torch.cat((k_c, k_x), dim=2)
119
+ elif self.rope_mode == 'none':
120
+ pass
121
+ else:
122
+ raise NotImplementedError
123
+ return q, k
124
+
125
+ def _attn(self, q, k, v, mask_binary):
126
+ if ATTENTION_MODE == 'flash':
127
+ x = F.scaled_dot_product_attention(
128
+ q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary
129
+ )
130
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
131
+ elif ATTENTION_MODE == 'math':
132
+ attn = (q @ k.transpose(-2, -1)) * self.scale
133
+ attn = add_mask(
134
+ attn, mask_binary
135
+ ) if mask_binary is not None else attn
136
+ attn = attn.softmax(dim=-1)
137
+ attn = self.attn_drop(attn)
138
+ x = (attn @ v).transpose(1, 2)
139
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
140
+ else:
141
+ raise NotImplementedError
142
+ return x
143
+
144
+ def forward(self, x, context=None, context_mask=None, extras=0):
145
+ B, L, C = x.shape
146
+ if context is None:
147
+ context = x
148
+
149
+ q = self.to_q(x)
150
+ k = self.to_k(context)
151
+ v = self.to_v(context)
152
+
153
+ if context_mask is not None:
154
+ mask_binary = create_mask(
155
+ x.shape, context.shape, x.device, None, context_mask
156
+ )
157
+ else:
158
+ mask_binary = None
159
+
160
+ q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
161
+ k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
162
+ v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)
163
+
164
+ q = self.norm_q(q)
165
+ k = self.norm_k(k)
166
+
167
+ q, k = self._rotary(q, k, extras)
168
+
169
+ x = self._attn(q, k, v, mask_binary)
170
+
171
+ x = self.proj(x)
172
+ x = self.proj_drop(x)
173
+ return x
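
A hypothetical usage sketch (the import path follows this repository's layout; the dimensions are arbitrary): the same module serves as self-attention when no `context_dim` is given and as cross-attention otherwise.

```python
import torch
from models.dit.attention import Attention

x = torch.randn(2, 100, 256)                  # (B, L, C) latent tokens
ctx = torch.randn(2, 32, 768)                 # (B, Lc, C_ctx) content tokens
ctx_mask = torch.ones(2, 32, dtype=torch.bool)

self_attn = Attention(dim=256, num_heads=8, qk_norm='rmsnorm')
cross_attn = Attention(dim=256, context_dim=768, num_heads=8, qk_norm='rmsnorm')

y = self_attn(x)                                        # self-attention
z = cross_attn(x, context=ctx, context_mask=ctx_mask)   # cross-attention
assert y.shape == z.shape == x.shape
```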
174
+
175
+
176
+ class JointAttention(nn.Module):
177
+ def __init__(
178
+ self,
179
+ dim,
180
+ num_heads=8,
181
+ qkv_bias=False,
182
+ qk_scale=None,
183
+ qk_norm=None,
184
+ attn_drop=0.,
185
+ proj_drop=0.,
186
+ rope_mode='none'
187
+ ):
188
+ super().__init__()
189
+ self.num_heads = num_heads
190
+ head_dim = dim // num_heads
191
+ self.scale = qk_scale or head_dim**-0.5
192
+
193
+ self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(
194
+ dim, qkv_bias
195
+ )
196
+ self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(
197
+ dim, qkv_bias
198
+ )
199
+
200
+ self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
201
+ self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)
202
+
203
+ self.attn_drop_p = attn_drop
204
+ self.attn_drop = nn.Dropout(attn_drop)
205
+
206
+ self.proj_x = nn.Linear(dim, dim)
207
+ self.proj_drop_x = nn.Dropout(proj_drop)
208
+
209
+ self.proj_c = nn.Linear(dim, dim)
210
+ self.proj_drop_c = nn.Dropout(proj_drop)
211
+
212
+ self.rope_mode = rope_mode
213
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
214
+ self.rotary = RotaryEmbedding(dim=head_dim)
215
+ elif self.rope_mode == 'dual':
216
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
217
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
218
+
219
+ def _make_qkv_layers(self, dim, qkv_bias):
220
+ return (
221
+ nn.Linear(dim, dim,
222
+ bias=qkv_bias), nn.Linear(dim, dim, bias=qkv_bias),
223
+ nn.Linear(dim, dim, bias=qkv_bias)
224
+ )
225
+
226
+ def _make_norm_layers(self, qk_norm, head_dim):
227
+ if qk_norm is None:
228
+ norm_q = nn.Identity()
229
+ norm_k = nn.Identity()
230
+ elif qk_norm == 'layernorm':
231
+ norm_q = nn.LayerNorm(head_dim)
232
+ norm_k = nn.LayerNorm(head_dim)
233
+ elif qk_norm == 'rmsnorm':
234
+ norm_q = RMSNorm(head_dim)
235
+ norm_k = RMSNorm(head_dim)
236
+ else:
237
+ raise NotImplementedError
238
+ return norm_q, norm_k
239
+
240
+ def _rotary(self, q, k, extras):
241
+ if self.rope_mode == 'shared':
242
+ q, k = self.rotary(q=q, k=k)
243
+ elif self.rope_mode == 'x_only':
244
+ q_x, k_x = self.rotary(
245
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
246
+ )
247
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
248
+ q = torch.cat((q_c, q_x), dim=2)
249
+ k = torch.cat((k_c, k_x), dim=2)
250
+ elif self.rope_mode == 'dual':
251
+ q_x, k_x = self.rotary_x(
252
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
253
+ )
254
+ q_c, k_c = self.rotary_c(
255
+ q=q[:, :, :extras, :], k=k[:, :, :extras, :]
256
+ )
257
+ q = torch.cat((q_c, q_x), dim=2)
258
+ k = torch.cat((k_c, k_x), dim=2)
259
+ elif self.rope_mode == 'none':
260
+ pass
261
+ else:
262
+ raise NotImplementedError
263
+ return q, k
264
+
265
+ def _attn(self, q, k, v, mask_binary):
266
+ if ATTENTION_MODE == 'flash':
267
+ x = F.scaled_dot_product_attention(
268
+ q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary
269
+ )
270
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
271
+ elif ATTENTION_MODE == 'math':
272
+ attn = (q @ k.transpose(-2, -1)) * self.scale
273
+ attn = add_mask(
274
+ attn, mask_binary
275
+ ) if mask_binary is not None else attn
276
+ attn = attn.softmax(dim=-1)
277
+ attn = self.attn_drop(attn)
278
+ x = (attn @ v).transpose(1, 2)
279
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
280
+ else:
281
+ raise NotImplementedError
282
+ return x
283
+
284
+ def _cat_mask(self, x, context, x_mask=None, context_mask=None):
285
+ B = x.shape[0]
286
+ if x_mask is None:
287
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
288
+ if context_mask is None:
289
+ context_mask = torch.ones(
290
+ B, context.shape[-2], device=context.device
291
+ ).bool()
292
+ mask = torch.cat([context_mask, x_mask], dim=1)
293
+ return mask
294
+
295
+ def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
296
+ B, Lx, C = x.shape
297
+ _, Lc, _ = context.shape
298
+ if x_mask is not None or context_mask is not None:
299
+ mask = self._cat_mask(
300
+ x, context, x_mask=x_mask, context_mask=context_mask
301
+ )
302
+ shape = [B, Lx + Lc, C]
303
+ mask_binary = create_mask(
304
+ q_shape=shape,
305
+ k_shape=shape,
306
+ device=x.device,
307
+ q_mask=None,
308
+ k_mask=mask
309
+ )
310
+ else:
311
+ mask_binary = None
312
+
313
+ qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
314
+ qc, kc, vc = self.to_qc(context), self.to_kc(context
315
+ ), self.to_vc(context)
316
+
317
+ qx, kx, vx = map(
318
+ lambda t: einops.
319
+ rearrange(t, 'B L (H D) -> B H L D', H=self.num_heads),
320
+ [qx, kx, vx]
321
+ )
322
+ qc, kc, vc = map(
323
+ lambda t: einops.
324
+ rearrange(t, 'B L (H D) -> B H L D', H=self.num_heads),
325
+ [qc, kc, vc]
326
+ )
327
+
328
+ qx, kx = self.norm_qx(qx), self.norm_kx(kx)
329
+ qc, kc = self.norm_qc(qc), self.norm_kc(kc)
330
+
331
+ q, k, v = (
332
+ torch.cat([qc, qx],
333
+ dim=2), torch.cat([kc, kx],
334
+ dim=2), torch.cat([vc, vx], dim=2)
335
+ )
336
+
337
+ q, k = self._rotary(q, k, extras)
338
+
339
+ x = self._attn(q, k, v, mask_binary)
340
+
341
+ context, x = x[:, :Lc, :], x[:, Lc:, :]
342
+
343
+ x = self.proj_x(x)
344
+ x = self.proj_drop_x(x)
345
+
346
+ context = self.proj_c(context)
347
+ context = self.proj_drop_c(context)
348
+
349
+ return x, context
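
Similarly, a hypothetical usage sketch for `JointAttention`, which attends over the concatenation of context and latent tokens and returns both streams:

```python
import torch
from models.dit.attention import JointAttention

joint = JointAttention(dim=256, num_heads=8, qk_norm='rmsnorm', rope_mode='none')
x = torch.randn(2, 100, 256)                  # latent tokens
c = torch.randn(2, 32, 256)                   # context tokens (same dim as x)
x_out, c_out = joint(x, c)
assert x_out.shape == x.shape and c_out.shape == c.shape
```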
models/dit/audio_diffsingernet_dit.py ADDED
@@ -0,0 +1,520 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from .mask_dit import DiTBlock, FinalBlock, UDiT
6
+ from .modules import (
7
+ film_modulate,
8
+ PatchEmbed,
9
+ PE_wrapper,
10
+ TimestepEmbedder,
11
+ RMSNorm,
12
+ )
13
+
14
+
15
+ class AudioDiTBlock(DiTBlock):
16
+ """
17
+ A modified DiT block with time_aligned_context add to latent.
18
+ """
19
+ def __init__(
20
+ self,
21
+ dim,
22
+ time_aligned_context_dim,
23
+ dilation,
24
+ context_dim=None,
25
+ num_heads=8,
26
+ mlp_ratio=4.,
27
+ qkv_bias=False,
28
+ qk_scale=None,
29
+ qk_norm=None,
30
+ act_layer='gelu',
31
+ norm_layer=nn.LayerNorm,
32
+ time_fusion='none',
33
+ ada_sola_rank=None,
34
+ ada_sola_alpha=None,
35
+ skip=False,
36
+ skip_norm=False,
37
+ rope_mode='none',
38
+ context_norm=False,
39
+ use_checkpoint=False
40
+ ):
41
+ super().__init__(
42
+ dim=dim,
43
+ context_dim=context_dim,
44
+ num_heads=num_heads,
45
+ mlp_ratio=mlp_ratio,
46
+ qkv_bias=qkv_bias,
47
+ qk_scale=qk_scale,
48
+ qk_norm=qk_norm,
49
+ act_layer=act_layer,
50
+ norm_layer=norm_layer,
51
+ time_fusion=time_fusion,
52
+ ada_sola_rank=ada_sola_rank,
53
+ ada_sola_alpha=ada_sola_alpha,
54
+ skip=skip,
55
+ skip_norm=skip_norm,
56
+ rope_mode=rope_mode,
57
+ context_norm=context_norm,
58
+ use_checkpoint=use_checkpoint
59
+ )
60
+ # time-aligned context projection
61
+ self.ta_context_projection = nn.Linear(
62
+ time_aligned_context_dim, 2 * dim
63
+ )
64
+ self.dilated_conv = nn.Conv1d(
65
+ dim, 2 * dim, kernel_size=3, padding=dilation, dilation=dilation
66
+ )
67
+
68
+ def forward(
69
+ self,
70
+ x,
71
+ time_aligned_context,
72
+ time_token=None,
73
+ time_ada=None,
74
+ skip=None,
75
+ context=None,
76
+ x_mask=None,
77
+ context_mask=None,
78
+ extras=None
79
+ ):
80
+ if self.use_checkpoint:
81
+ return checkpoint(
82
+ self._forward,
83
+ x,
84
+ time_aligned_context,
85
+ time_token,
86
+ time_ada,
87
+ skip,
88
+ context,
89
+ x_mask,
90
+ context_mask,
91
+ extras,
92
+ use_reentrant=False
93
+ )
94
+ else:
95
+ return self._forward(
96
+ x,
97
+ time_aligned_context,
98
+ time_token,
99
+ time_ada,
100
+ skip,
101
+ context,
102
+ x_mask,
103
+ context_mask,
104
+ extras,
105
+ )
106
+
107
+ def _forward(
108
+ self,
109
+ x,
110
+ time_aligned_context,
111
+ time_token=None,
112
+ time_ada=None,
113
+ skip=None,
114
+ context=None,
115
+ x_mask=None,
116
+ context_mask=None,
117
+ extras=None
118
+ ):
119
+ B, T, C = x.shape
120
+ if self.skip_linear is not None:
121
+ assert skip is not None
122
+ cat = torch.cat([x, skip], dim=-1)
123
+ cat = self.skip_norm(cat)
124
+ x = self.skip_linear(cat)
125
+
126
+ if self.use_adanorm:
127
+ time_ada = self.adaln(time_token, time_ada)
128
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
129
+ gate_mlp) = time_ada.chunk(6, dim=1)
130
+
131
+ # self attention
132
+ if self.use_adanorm:
133
+ x_norm = film_modulate(
134
+ self.norm1(x), shift=shift_msa, scale=scale_msa
135
+ )
136
+ x = x + (1-gate_msa) * self.attn(
137
+ x_norm, context=None, context_mask=x_mask, extras=extras
138
+ )
139
+ else:
140
+ # TODO diffusion timestep input is not fused here
141
+ x = x + self.attn(
142
+ self.norm1(x),
143
+ context=None,
144
+ context_mask=x_mask,
145
+ extras=extras
146
+ )
147
+
148
+ # time-aligned context
149
+ time_aligned_context = self.ta_context_projection(time_aligned_context)
150
+ x = self.dilated_conv(x.transpose(1, 2)
151
+ ).transpose(1, 2) + time_aligned_context
152
+
153
+ gate, filter = torch.chunk(x, 2, dim=-1)
154
+ x = torch.sigmoid(gate) * torch.tanh(filter)
155
+
156
+ # cross attention
157
+ if self.use_context:
158
+ assert context is not None
159
+ x = x + self.cross_attn(
160
+ x=self.norm2(x),
161
+ context=self.norm_context(context),
162
+ context_mask=context_mask,
163
+ extras=extras
164
+ )
165
+
166
+ # mlp
167
+ if self.use_adanorm:
168
+ x_norm = film_modulate(
169
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
170
+ )
171
+ x = x + (1-gate_mlp) * self.mlp(x_norm)
172
+ else:
173
+ x = x + self.mlp(self.norm3(x))
174
+
175
+ return x
176
+
177
+
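The block above fuses the time-aligned context with a WaveNet/DiffSinger-style gated dilated convolution: the context is projected to 2*dim, added to a dilated Conv1d over the latent, and split into a sigmoid gate and a tanh filter. Below is a minimal, self-contained sketch of just that fusion step; it is illustrative only and the tensor sizes are arbitrary, not the values used in this Space.

import torch
import torch.nn as nn

dim, ta_dim, dilation = 64, 32, 2
ta_proj = nn.Linear(ta_dim, 2 * dim)
dilated_conv = nn.Conv1d(dim, 2 * dim, kernel_size=3, padding=dilation, dilation=dilation)

x = torch.randn(4, 100, dim)            # (B, T, C) latent tokens
ta_ctx = torch.randn(4, 100, ta_dim)    # time-aligned context, one vector per frame

# dilated conv over the latent plus the projected context, then gated activation
h = dilated_conv(x.transpose(1, 2)).transpose(1, 2) + ta_proj(ta_ctx)
gate, filt = torch.chunk(h, 2, dim=-1)
out = torch.sigmoid(gate) * torch.tanh(filt)
print(out.shape)                        # torch.Size([4, 100, 64])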
178
+ class AudioUDiT(UDiT):
179
+ def __init__(
180
+ self,
181
+ img_size=224,
182
+ patch_size=16,
183
+ in_chans=3,
184
+ input_type='2d',
185
+ out_chans=None,
186
+ embed_dim=768,
187
+ depth=12,
188
+ dilation_cycle_length=4,
189
+ num_heads=12,
190
+ mlp_ratio=4,
191
+ qkv_bias=False,
192
+ qk_scale=None,
193
+ qk_norm=None,
194
+ act_layer='gelu',
195
+ norm_layer='layernorm',
196
+ context_norm=False,
197
+ use_checkpoint=False,
198
+ time_fusion='token',
199
+ ada_sola_rank=None,
200
+ ada_sola_alpha=None,
201
+ cls_dim=None,
202
+ time_aligned_context_dim=768,
203
+ context_dim=768,
204
+ context_fusion='concat',
205
+ context_max_length=128,
206
+ context_pe_method='sinu',
207
+ pe_method='abs',
208
+ rope_mode='none',
209
+ use_conv=True,
210
+ skip=True,
211
+ skip_norm=True
212
+ ):
213
+ nn.Module.__init__(self)
214
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
215
+
216
+ # input
217
+ self.in_chans = in_chans
218
+ self.input_type = input_type
219
+ if self.input_type == '2d':
220
+ num_patches = (img_size[0] //
221
+ patch_size) * (img_size[1] // patch_size)
222
+ elif self.input_type == '1d':
223
+ num_patches = img_size // patch_size
224
+ self.patch_embed = PatchEmbed(
225
+ patch_size=patch_size,
226
+ in_chans=in_chans,
227
+ embed_dim=embed_dim,
228
+ input_type=input_type
229
+ )
230
+ out_chans = in_chans if out_chans is None else out_chans
231
+ self.out_chans = out_chans
232
+
233
+ # position embedding
234
+ self.rope = rope_mode
235
+ self.x_pe = PE_wrapper(
236
+ dim=embed_dim, method=pe_method, length=num_patches
237
+ )
238
+
239
+ # time embed
240
+ self.time_embed = TimestepEmbedder(embed_dim)
241
+ self.time_fusion = time_fusion
242
+ self.use_adanorm = False
243
+
244
+ # cls embed
245
+ if cls_dim is not None:
246
+ self.cls_embed = nn.Sequential(
247
+ nn.Linear(cls_dim, embed_dim, bias=True),
248
+ nn.SiLU(),
249
+ nn.Linear(embed_dim, embed_dim, bias=True),
250
+ )
251
+ else:
252
+ self.cls_embed = None
253
+
254
+ # time fusion
255
+ if time_fusion == 'token':
256
+ # put token at the beginning of sequence
257
+ self.extras = 2 if self.cls_embed else 1
258
+ self.time_pe = PE_wrapper(
259
+ dim=embed_dim, method='abs', length=self.extras
260
+ )
261
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
262
+ self.use_adanorm = True
263
+ # avoid repeating the SiLU in every adaLN block
264
+ self.time_act = nn.SiLU()
265
+ self.extras = 0
266
+ self.time_ada_final = nn.Linear(
267
+ embed_dim, 2 * embed_dim, bias=True
268
+ )
269
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
270
+ # shared adaln
271
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
272
+ else:
273
+ self.time_ada = None
274
+ else:
275
+ raise NotImplementedError
276
+
277
+ # context
278
+ # use a simple projection
279
+ self.use_context = False
280
+ self.context_cross = False
281
+ self.context_max_length = context_max_length
282
+ self.context_fusion = 'none'
283
+ if context_dim is not None:
284
+ self.use_context = True
285
+ self.context_embed = nn.Sequential(
286
+ nn.Linear(context_dim, embed_dim, bias=True),
287
+ nn.SiLU(),
288
+ nn.Linear(embed_dim, embed_dim, bias=True),
289
+ )
290
+ self.context_fusion = context_fusion
291
+ if context_fusion == 'concat' or context_fusion == 'joint':
292
+ self.extras += context_max_length
293
+ self.context_pe = PE_wrapper(
294
+ dim=embed_dim,
295
+ method=context_pe_method,
296
+ length=context_max_length
297
+ )
298
+ # no cross attention layers
299
+ context_dim = None
300
+ elif context_fusion == 'cross':
301
+ self.context_pe = PE_wrapper(
302
+ dim=embed_dim,
303
+ method=context_pe_method,
304
+ length=context_max_length
305
+ )
306
+ self.context_cross = True
307
+ context_dim = embed_dim
308
+ else:
309
+ raise NotImplementedError
310
+
311
+ self.use_skip = skip
312
+
313
+ # norm layers
314
+ if norm_layer == 'layernorm':
315
+ norm_layer = nn.LayerNorm
316
+ elif norm_layer == 'rmsnorm':
317
+ norm_layer = RMSNorm
318
+ else:
319
+ raise NotImplementedError
320
+
321
+ self.in_blocks = nn.ModuleList([
322
+ AudioDiTBlock(
323
+ dim=embed_dim,
324
+ time_aligned_context_dim=time_aligned_context_dim,
325
+ dilation=2**(i % dilation_cycle_length),
326
+ context_dim=context_dim,
327
+ num_heads=num_heads,
328
+ mlp_ratio=mlp_ratio,
329
+ qkv_bias=qkv_bias,
330
+ qk_scale=qk_scale,
331
+ qk_norm=qk_norm,
332
+ act_layer=act_layer,
333
+ norm_layer=norm_layer,
334
+ time_fusion=time_fusion,
335
+ ada_sola_rank=ada_sola_rank,
336
+ ada_sola_alpha=ada_sola_alpha,
337
+ skip=False,
338
+ skip_norm=False,
339
+ rope_mode=self.rope,
340
+ context_norm=context_norm,
341
+ use_checkpoint=use_checkpoint
342
+ ) for i in range(depth // 2)
343
+ ])
344
+
345
+ self.mid_block = AudioDiTBlock(
346
+ dim=embed_dim,
347
+ time_aligned_context_dim=time_aligned_context_dim,
348
+ dilation=1,
349
+ context_dim=context_dim,
350
+ num_heads=num_heads,
351
+ mlp_ratio=mlp_ratio,
352
+ qkv_bias=qkv_bias,
353
+ qk_scale=qk_scale,
354
+ qk_norm=qk_norm,
355
+ act_layer=act_layer,
356
+ norm_layer=norm_layer,
357
+ time_fusion=time_fusion,
358
+ ada_sola_rank=ada_sola_rank,
359
+ ada_sola_alpha=ada_sola_alpha,
360
+ skip=False,
361
+ skip_norm=False,
362
+ rope_mode=self.rope,
363
+ context_norm=context_norm,
364
+ use_checkpoint=use_checkpoint
365
+ )
366
+
367
+ self.out_blocks = nn.ModuleList([
368
+ AudioDiTBlock(
369
+ dim=embed_dim,
370
+ time_aligned_context_dim=time_aligned_context_dim,
371
+ dilation=2**(i % dilation_cycle_length),
372
+ context_dim=context_dim,
373
+ num_heads=num_heads,
374
+ mlp_ratio=mlp_ratio,
375
+ qkv_bias=qkv_bias,
376
+ qk_scale=qk_scale,
377
+ qk_norm=qk_norm,
378
+ act_layer=act_layer,
379
+ norm_layer=norm_layer,
380
+ time_fusion=time_fusion,
381
+ ada_sola_rank=ada_sola_rank,
382
+ ada_sola_alpha=ada_sola_alpha,
383
+ skip=skip,
384
+ skip_norm=skip_norm,
385
+ rope_mode=self.rope,
386
+ context_norm=context_norm,
387
+ use_checkpoint=use_checkpoint
388
+ ) for i in range(depth // 2)
389
+ ])
390
+
391
+ # FinalLayer block
392
+ self.use_conv = use_conv
393
+ self.final_block = FinalBlock(
394
+ embed_dim=embed_dim,
395
+ patch_size=patch_size,
396
+ img_size=img_size,
397
+ in_chans=out_chans,
398
+ input_type=input_type,
399
+ norm_layer=norm_layer,
400
+ use_conv=use_conv,
401
+ use_adanorm=self.use_adanorm
402
+ )
403
+ self.initialize_weights()
404
+
405
+ def forward(
406
+ self,
407
+ x,
408
+ timesteps,
409
+ time_aligned_context,
410
+ context,
411
+ x_mask=None,
412
+ context_mask=None,
413
+ cls_token=None,
414
+ controlnet_skips=None,
415
+ ):
416
+ # make it compatible with int time step during inference
417
+ if timesteps.dim() == 0:
418
+ timesteps = timesteps.expand(x.shape[0]
419
+ ).to(x.device, dtype=torch.long)
420
+
421
+ x = self.patch_embed(x)
422
+ x = self.x_pe(x)
423
+
424
+ B, L, D = x.shape
425
+
426
+ if self.use_context:
427
+ context_token = self.context_embed(context)
428
+ context_token = self.context_pe(context_token)
429
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
430
+ x, x_mask = self._concat_x_context(
431
+ x=x,
432
+ context=context_token,
433
+ x_mask=x_mask,
434
+ context_mask=context_mask
435
+ )
436
+ context_token, context_mask = None, None
437
+ else:
438
+ context_token, context_mask = None, None
439
+
440
+ time_token = self.time_embed(timesteps)
441
+ if self.cls_embed:
442
+ cls_token = self.cls_embed(cls_token)
443
+ time_ada = None
444
+ time_ada_final = None
445
+ if self.use_adanorm:
446
+ if self.cls_embed:
447
+ time_token = time_token + cls_token
448
+ time_token = self.time_act(time_token)
449
+ time_ada_final = self.time_ada_final(time_token)
450
+ if self.time_ada is not None:
451
+ time_ada = self.time_ada(time_token)
452
+ else:
453
+ time_token = time_token.unsqueeze(dim=1)
454
+ if self.cls_embed:
455
+ cls_token = cls_token.unsqueeze(dim=1)
456
+ time_token = torch.cat([time_token, cls_token], dim=1)
457
+ time_token = self.time_pe(time_token)
458
+ x = torch.cat((time_token, x), dim=1)
459
+ if x_mask is not None:
460
+ x_mask = torch.cat([
461
+ torch.ones(B, time_token.shape[1],
462
+ device=x_mask.device).bool(), x_mask
463
+ ],
464
+ dim=1)
465
+ time_token = None
466
+
467
+ skips = []
468
+ for blk in self.in_blocks:
469
+ x = blk(
470
+ x=x,
471
+ time_aligned_context=time_aligned_context,
472
+ time_token=time_token,
473
+ time_ada=time_ada,
474
+ skip=None,
475
+ context=context_token,
476
+ x_mask=x_mask,
477
+ context_mask=context_mask,
478
+ extras=self.extras
479
+ )
480
+ if self.use_skip:
481
+ skips.append(x)
482
+
483
+ x = self.mid_block(
484
+ x=x,
485
+ time_aligned_context=time_aligned_context,
486
+ time_token=time_token,
487
+ time_ada=time_ada,
488
+ skip=None,
489
+ context=context_token,
490
+ x_mask=x_mask,
491
+ context_mask=context_mask,
492
+ extras=self.extras
493
+ )
494
+ for blk in self.out_blocks:
495
+ if self.use_skip:
496
+ skip = skips.pop()
497
+ if controlnet_skips:
498
+ # add to skip like u-net controlnet
499
+ skip = skip + controlnet_skips.pop()
500
+ else:
501
+ skip = None
502
+ if controlnet_skips:
503
+ # directly add to x
504
+ x = x + controlnet_skips.pop()
505
+
506
+ x = blk(
507
+ x=x,
508
+ time_aligned_context=time_aligned_context,
509
+ time_token=time_token,
510
+ time_ada=time_ada,
511
+ skip=skip,
512
+ context=context_token,
513
+ x_mask=x_mask,
514
+ context_mask=context_mask,
515
+ extras=self.extras
516
+ )
517
+
518
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
519
+
520
+ return x
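For reference, a shape-level sketch of driving AudioUDiT end to end. This is illustrative only: it assumes the repository root is on PYTHONPATH, and the sizes and keyword choices are arbitrary rather than the values from configs/config.yaml. The combination time_fusion='ada_single' and context_fusion='cross' is chosen so that no extra prefix tokens are inserted and the time-aligned context stays frame-aligned with the latent.

import torch
from models.dit.audio_diffsingernet_dit import AudioUDiT

model = AudioUDiT(
    img_size=200, patch_size=1, in_chans=64, input_type='1d',
    embed_dim=256, depth=4, num_heads=4,
    time_fusion='ada_single', time_aligned_context_dim=128,
    context_dim=512, context_fusion='cross', context_max_length=32,
)
x = torch.randn(2, 64, 200)              # (B, C, T) latent
t = torch.randint(0, 1000, (2,))         # diffusion timesteps
ta_ctx = torch.randn(2, 200, 128)        # frame-aligned context, one vector per frame
ctx = torch.randn(2, 32, 512)            # e.g. a text/content encoder output
out = model(x, t, time_aligned_context=ta_ctx, context=ctx)
print(out.shape)                         # expected (2, 64, 200)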
models/dit/audio_dit.py ADDED
@@ -0,0 +1,652 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from .mask_dit import DiTBlock, FinalBlock, UDiT
6
+ from .modules import (
7
+ film_modulate,
8
+ PatchEmbed,
9
+ PE_wrapper,
10
+ TimestepEmbedder,
11
+ RMSNorm,
12
+ )
13
+
14
+
15
+ class LayerFusionDiTBlock(DiTBlock):
16
+ """
17
+ A modified DiT block that adds time-aligned context to the latent.
18
+ """
19
+ def __init__(
20
+ self,
21
+ dim,
22
+ ta_context_dim,
23
+ ta_context_norm=False,
24
+ context_dim=None,
25
+ num_heads=8,
26
+ mlp_ratio=4.,
27
+ qkv_bias=False,
28
+ qk_scale=None,
29
+ qk_norm=None,
30
+ act_layer='gelu',
31
+ norm_layer=nn.LayerNorm,
32
+ ta_context_fusion='add',
33
+ time_fusion='none',
34
+ ada_sola_rank=None,
35
+ ada_sola_alpha=None,
36
+ skip=False,
37
+ skip_norm=False,
38
+ rope_mode='none',
39
+ context_norm=False,
40
+ use_checkpoint=False
41
+ ):
42
+ super().__init__(
43
+ dim=dim,
44
+ context_dim=context_dim,
45
+ num_heads=num_heads,
46
+ mlp_ratio=mlp_ratio,
47
+ qkv_bias=qkv_bias,
48
+ qk_scale=qk_scale,
49
+ qk_norm=qk_norm,
50
+ act_layer=act_layer,
51
+ norm_layer=norm_layer,
52
+ time_fusion=time_fusion,
53
+ ada_sola_rank=ada_sola_rank,
54
+ ada_sola_alpha=ada_sola_alpha,
55
+ skip=skip,
56
+ skip_norm=skip_norm,
57
+ rope_mode=rope_mode,
58
+ context_norm=context_norm,
59
+ use_checkpoint=use_checkpoint
60
+ )
61
+ self.ta_context_fusion = ta_context_fusion
62
+ self.ta_context_norm = ta_context_norm
63
+ if self.ta_context_fusion == "add":
64
+ self.ta_context_projection = nn.Linear(
65
+ ta_context_dim, dim, bias=False
66
+ )
67
+ self.ta_context_norm = norm_layer(
68
+ ta_context_dim
69
+ ) if self.ta_context_norm else nn.Identity()
70
+ elif self.ta_context_fusion == "concat":
71
+ self.ta_context_projection = nn.Linear(ta_context_dim + dim, dim)
72
+ self.ta_context_norm = norm_layer(
73
+ ta_context_dim + dim
74
+ ) if self.ta_context_norm else nn.Identity()
75
+
76
+ def forward(
77
+ self,
78
+ x,
79
+ time_aligned_context,
80
+ time_token=None,
81
+ time_ada=None,
82
+ skip=None,
83
+ context=None,
84
+ x_mask=None,
85
+ context_mask=None,
86
+ extras=None
87
+ ):
88
+ if self.use_checkpoint:
89
+ return checkpoint(
90
+ self._forward,
91
+ x,
92
+ time_aligned_context,
93
+ time_token,
94
+ time_ada,
95
+ skip,
96
+ context,
97
+ x_mask,
98
+ context_mask,
99
+ extras,
100
+ use_reentrant=False
101
+ )
102
+ else:
103
+ return self._forward(
104
+ x,
105
+ time_aligned_context,
106
+ time_token,
107
+ time_ada,
108
+ skip,
109
+ context,
110
+ x_mask,
111
+ context_mask,
112
+ extras,
113
+ )
114
+
115
+ def _forward(
116
+ self,
117
+ x,
118
+ time_aligned_context,
119
+ time_token=None,
120
+ time_ada=None,
121
+ skip=None,
122
+ context=None,
123
+ x_mask=None,
124
+ context_mask=None,
125
+ extras=None
126
+ ):
127
+ B, T, C = x.shape
128
+
129
+ # skip connection
130
+ if self.skip_linear is not None:
131
+ assert skip is not None
132
+ cat = torch.cat([x, skip], dim=-1)
133
+ cat = self.skip_norm(cat)
134
+ x = self.skip_linear(cat)
135
+
136
+ if self.use_adanorm:
137
+ time_ada = self.adaln(time_token, time_ada)
138
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
139
+ gate_mlp) = time_ada.chunk(6, dim=1)
140
+
141
+ # self attention
142
+ if self.use_adanorm:
143
+ x_norm = film_modulate(
144
+ self.norm1(x), shift=shift_msa, scale=scale_msa
145
+ )
146
+ tanh_gate_msa = torch.tanh(1 - gate_msa)
147
+ x = x + tanh_gate_msa * self.attn(
148
+ x_norm, context=None, context_mask=x_mask, extras=extras
149
+ )
150
+ # x = x + (1 - gate_msa) * self.attn(
151
+ # x_norm, context=None, context_mask=x_mask, extras=extras
152
+ # )
153
+ else:
154
+ # TODO diffusion timestep input is not fused here
155
+ x = x + self.attn(
156
+ self.norm1(x),
157
+ context=None,
158
+ context_mask=x_mask,
159
+ extras=extras
160
+ )
161
+
162
+ # time aligned context fusion
163
+ if self.ta_context_fusion == "add":
164
+ time_aligned_context = self.ta_context_projection(
165
+ self.ta_context_norm(time_aligned_context)
166
+ )
167
+ if time_aligned_context.size(1) < x.size(1):
168
+ time_aligned_context = nn.functional.pad(
169
+ time_aligned_context, (0, 0, 1, 0)
170
+ )
171
+ x = x + time_aligned_context
172
+ elif self.ta_context_fusion == "concat":
173
+ if time_aligned_context.size(1) < x.size(1):
174
+ time_aligned_context = nn.functional.pad(
175
+ time_aligned_context, (0, 0, 1, 0)
176
+ )
177
+ cat = torch.cat([x, time_aligned_context], dim=-1)
178
+ cat = self.ta_context_norm(cat)
179
+ x = self.ta_context_projection(cat)
180
+
181
+ # cross attention
182
+ if self.use_context:
183
+ assert context is not None
184
+ x = x + self.cross_attn(
185
+ x=self.norm2(x),
186
+ context=self.norm_context(context),
187
+ context_mask=context_mask,
188
+ extras=extras
189
+ )
190
+
191
+ # mlp
192
+ if self.use_adanorm:
193
+ x_norm = film_modulate(
194
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
195
+ )
196
+ x = x + (1 - gate_mlp) * self.mlp(x_norm)
197
+ else:
198
+ x = x + self.mlp(self.norm3(x))
199
+
200
+ return x
201
+
202
+
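A self-contained sketch (illustrative only, arbitrary sizes) of the two ta_context_fusion modes implemented by LayerFusionDiTBlock above. When the latent sequence carries one extra prepended token, the context is left-padded by a frame so the two sequences line up, as in _forward above.

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, ta_dim = 64, 32
x = torch.randn(2, 101, dim)            # latent tokens incl. one prepended prefix token
ta_ctx = torch.randn(2, 100, ta_dim)    # time-aligned context, one frame shorter
padded = F.pad(ta_ctx, (0, 0, 1, 0))    # pad one frame at the front of the time axis

# 'add': project the context to dim and add it as a residual
add_proj = nn.Linear(ta_dim, dim, bias=False)
x_add = x + add_proj(padded)

# 'concat': concatenate along channels, then project back to dim
cat_proj = nn.Linear(ta_dim + dim, dim)
x_cat = cat_proj(torch.cat([x, padded], dim=-1))
print(x_add.shape, x_cat.shape)         # both torch.Size([2, 101, 64])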
203
+ class LayerFusionAudioDiT(UDiT):
204
+ def __init__(
205
+ self,
206
+ img_size=224,
207
+ patch_size=16,
208
+ in_chans=3,
209
+ input_type='2d',
210
+ out_chans=None,
211
+ embed_dim=768,
212
+ depth=12,
213
+ num_heads=12,
214
+ mlp_ratio=4,
215
+ qkv_bias=False,
216
+ qk_scale=None,
217
+ qk_norm=None,
218
+ act_layer='gelu',
219
+ norm_layer='layernorm',
220
+ context_norm=False,
221
+ use_checkpoint=False,
222
+ time_fusion='token',
223
+ ada_sola_rank=None,
224
+ ada_sola_alpha=None,
225
+ cls_dim=None,
226
+ ta_context_dim=768,
227
+ ta_context_fusion='concat',
228
+ ta_context_norm=True,
229
+ context_dim=768,
230
+ context_fusion='concat',
231
+ context_max_length=128,
232
+ context_pe_method='sinu',
233
+ pe_method='abs',
234
+ rope_mode='none',
235
+ use_conv=True,
236
+ skip=True,
237
+ skip_norm=True
238
+ ):
239
+ nn.Module.__init__(self)
240
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
241
+
242
+ # input
243
+ self.in_chans = in_chans
244
+ self.input_type = input_type
245
+ if self.input_type == '2d':
246
+ num_patches = (img_size[0] //
247
+ patch_size) * (img_size[1] // patch_size)
248
+ elif self.input_type == '1d':
249
+ num_patches = img_size // patch_size
250
+ self.patch_embed = PatchEmbed(
251
+ patch_size=patch_size,
252
+ in_chans=in_chans,
253
+ embed_dim=embed_dim,
254
+ input_type=input_type
255
+ )
256
+ out_chans = in_chans if out_chans is None else out_chans
257
+ self.out_chans = out_chans
258
+
259
+ # position embedding
260
+ self.rope = rope_mode
261
+ self.x_pe = PE_wrapper(
262
+ dim=embed_dim, method=pe_method, length=num_patches
263
+ )
264
+
265
+ # time embed
266
+ self.time_embed = TimestepEmbedder(embed_dim)
267
+ self.time_fusion = time_fusion
268
+ self.use_adanorm = False
269
+
270
+ # cls embed
271
+ if cls_dim is not None:
272
+ self.cls_embed = nn.Sequential(
273
+ nn.Linear(cls_dim, embed_dim, bias=True),
274
+ nn.SiLU(),
275
+ nn.Linear(embed_dim, embed_dim, bias=True),
276
+ )
277
+ else:
278
+ self.cls_embed = None
279
+
280
+ # time fusion
281
+ if time_fusion == 'token':
282
+ # put token at the beginning of sequence
283
+ self.extras = 2 if self.cls_embed else 1
284
+ self.time_pe = PE_wrapper(
285
+ dim=embed_dim, method='abs', length=self.extras
286
+ )
287
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
288
+ self.use_adanorm = True
289
+ # avoid repeating the SiLU in every adaLN block
290
+ self.time_act = nn.SiLU()
291
+ self.extras = 0
292
+ self.time_ada_final = nn.Linear(
293
+ embed_dim, 2 * embed_dim, bias=True
294
+ )
295
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
296
+ # shared adaln
297
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
298
+ else:
299
+ self.time_ada = None
300
+ else:
301
+ raise NotImplementedError
302
+
303
+ # context
304
+ # use a simple projection
305
+ self.use_context = False
306
+ self.context_cross = False
307
+ self.context_max_length = context_max_length
308
+ self.context_fusion = 'none'
309
+ if context_dim is not None:
310
+ self.use_context = True
311
+ self.context_embed = nn.Sequential(
312
+ nn.Linear(context_dim, embed_dim, bias=True),
313
+ nn.SiLU(),
314
+ nn.Linear(embed_dim, embed_dim, bias=True),
315
+ )
316
+ self.context_fusion = context_fusion
317
+ if context_fusion == 'concat' or context_fusion == 'joint':
318
+ self.extras += context_max_length
319
+ self.context_pe = PE_wrapper(
320
+ dim=embed_dim,
321
+ method=context_pe_method,
322
+ length=context_max_length
323
+ )
324
+ # no cross attention layers
325
+ context_dim = None
326
+ elif context_fusion == 'cross':
327
+ self.context_pe = PE_wrapper(
328
+ dim=embed_dim,
329
+ method=context_pe_method,
330
+ length=context_max_length
331
+ )
332
+ self.context_cross = True
333
+ context_dim = embed_dim
334
+ else:
335
+ raise NotImplementedError
336
+
337
+ self.use_skip = skip
338
+
339
+ # norm layers
340
+ if norm_layer == 'layernorm':
341
+ norm_layer = nn.LayerNorm
342
+ elif norm_layer == 'rmsnorm':
343
+ norm_layer = RMSNorm
344
+ else:
345
+ raise NotImplementedError
346
+
347
+ self.in_blocks = nn.ModuleList([
348
+ LayerFusionDiTBlock(
349
+ dim=embed_dim,
350
+ ta_context_dim=ta_context_dim,
351
+ ta_context_fusion=ta_context_fusion,
352
+ ta_context_norm=ta_context_norm,
353
+ context_dim=context_dim,
354
+ num_heads=num_heads,
355
+ mlp_ratio=mlp_ratio,
356
+ qkv_bias=qkv_bias,
357
+ qk_scale=qk_scale,
358
+ qk_norm=qk_norm,
359
+ act_layer=act_layer,
360
+ norm_layer=norm_layer,
361
+ time_fusion=time_fusion,
362
+ ada_sola_rank=ada_sola_rank,
363
+ ada_sola_alpha=ada_sola_alpha,
364
+ skip=False,
365
+ skip_norm=False,
366
+ rope_mode=self.rope,
367
+ context_norm=context_norm,
368
+ use_checkpoint=use_checkpoint
369
+ ) for i in range(depth // 2)
370
+ ])
371
+
372
+ self.mid_block = LayerFusionDiTBlock(
373
+ dim=embed_dim,
374
+ ta_context_dim=ta_context_dim,
375
+ context_dim=context_dim,
376
+ num_heads=num_heads,
377
+ mlp_ratio=mlp_ratio,
378
+ qkv_bias=qkv_bias,
379
+ qk_scale=qk_scale,
380
+ qk_norm=qk_norm,
381
+ act_layer=act_layer,
382
+ norm_layer=norm_layer,
383
+ time_fusion=time_fusion,
384
+ ada_sola_rank=ada_sola_rank,
385
+ ada_sola_alpha=ada_sola_alpha,
386
+ ta_context_fusion=ta_context_fusion,
387
+ ta_context_norm=ta_context_norm,
388
+ skip=False,
389
+ skip_norm=False,
390
+ rope_mode=self.rope,
391
+ context_norm=context_norm,
392
+ use_checkpoint=use_checkpoint
393
+ )
394
+
395
+ self.out_blocks = nn.ModuleList([
396
+ LayerFusionDiTBlock(
397
+ dim=embed_dim,
398
+ ta_context_dim=ta_context_dim,
399
+ context_dim=context_dim,
400
+ num_heads=num_heads,
401
+ mlp_ratio=mlp_ratio,
402
+ qkv_bias=qkv_bias,
403
+ qk_scale=qk_scale,
404
+ qk_norm=qk_norm,
405
+ act_layer=act_layer,
406
+ norm_layer=norm_layer,
407
+ time_fusion=time_fusion,
408
+ ada_sola_rank=ada_sola_rank,
409
+ ada_sola_alpha=ada_sola_alpha,
410
+ ta_context_fusion=ta_context_fusion,
411
+ ta_context_norm=ta_context_norm,
412
+ skip=skip,
413
+ skip_norm=skip_norm,
414
+ rope_mode=self.rope,
415
+ context_norm=context_norm,
416
+ use_checkpoint=use_checkpoint
417
+ ) for i in range(depth // 2)
418
+ ])
419
+
420
+ # FinalLayer block
421
+ self.use_conv = use_conv
422
+ self.final_block = FinalBlock(
423
+ embed_dim=embed_dim,
424
+ patch_size=patch_size,
425
+ img_size=img_size,
426
+ in_chans=out_chans,
427
+ input_type=input_type,
428
+ norm_layer=norm_layer,
429
+ use_conv=use_conv,
430
+ use_adanorm=self.use_adanorm
431
+ )
432
+ self.initialize_weights()
433
+
434
+ def forward(
435
+ self,
436
+ x,
437
+ timesteps,
438
+ time_aligned_context,
439
+ context,
440
+ x_mask=None,
441
+ context_mask=None,
442
+ cls_token=None,
443
+ controlnet_skips=None,
444
+ ):
445
+ # make it compatible with int time step during inference
446
+ if timesteps.dim() == 0:
447
+ timesteps = timesteps.expand(x.shape[0]
448
+ ).to(x.device, dtype=torch.long)
449
+
450
+ x = self.patch_embed(x)
451
+ x = self.x_pe(x)
452
+
453
+ B, L, D = x.shape
454
+
455
+ if self.use_context:
456
+ context_token = self.context_embed(context)
457
+ context_token = self.context_pe(context_token)
458
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
459
+ x, x_mask = self._concat_x_context(
460
+ x=x,
461
+ context=context_token,
462
+ x_mask=x_mask,
463
+ context_mask=context_mask
464
+ )
465
+ context_token, context_mask = None, None
466
+ else:
467
+ context_token, context_mask = None, None
468
+
469
+ time_token = self.time_embed(timesteps)
470
+ if self.cls_embed:
471
+ cls_token = self.cls_embed(cls_token)
472
+ time_ada = None
473
+ time_ada_final = None
474
+ if self.use_adanorm:
475
+ if self.cls_embed:
476
+ time_token = time_token + cls_token
477
+ time_token = self.time_act(time_token)
478
+ time_ada_final = self.time_ada_final(time_token)
479
+ if self.time_ada is not None:
480
+ time_ada = self.time_ada(time_token)
481
+ else:
482
+ time_token = time_token.unsqueeze(dim=1)
483
+ if self.cls_embed:
484
+ cls_token = cls_token.unsqueeze(dim=1)
485
+ time_token = torch.cat([time_token, cls_token], dim=1)
486
+ time_token = self.time_pe(time_token)
487
+ x = torch.cat((time_token, x), dim=1)
488
+ if x_mask is not None:
489
+ x_mask = torch.cat([
490
+ torch.ones(B, time_token.shape[1],
491
+ device=x_mask.device).bool(), x_mask
492
+ ],
493
+ dim=1)
494
+ time_token = None
495
+
496
+ skips = []
497
+ for blk_idx, blk in enumerate(self.in_blocks):
498
+ x = blk(
499
+ x=x,
500
+ time_aligned_context=time_aligned_context,
501
+ time_token=time_token,
502
+ time_ada=time_ada,
503
+ skip=None,
504
+ context=context_token,
505
+ x_mask=x_mask,
506
+ context_mask=context_mask,
507
+ extras=self.extras
508
+ )
509
+ # if not self.training:
510
+ # print(
511
+ # f"in block {blk_idx}, min: {x.min().item()}, max: {x.max().item()}, std: {x.std().item()}"
512
+ # )
513
+ if self.use_skip:
514
+ skips.append(x)
515
+
516
+ x = self.mid_block(
517
+ x=x,
518
+ time_aligned_context=time_aligned_context,
519
+ time_token=time_token,
520
+ time_ada=time_ada,
521
+ skip=None,
522
+ context=context_token,
523
+ x_mask=x_mask,
524
+ context_mask=context_mask,
525
+ extras=self.extras
526
+ )
527
+ for blk_idx, blk in enumerate(self.out_blocks):
528
+ if self.use_skip:
529
+ skip = skips.pop()
530
+ if controlnet_skips:
531
+ # add to skip like u-net controlnet
532
+ skip = skip + controlnet_skips.pop()
533
+ else:
534
+ skip = None
535
+ if controlnet_skips:
536
+ # directly add to x
537
+ x = x + controlnet_skips.pop()
538
+
539
+ x = blk(
540
+ x=x,
541
+ time_aligned_context=time_aligned_context,
542
+ time_token=time_token,
543
+ time_ada=time_ada,
544
+ skip=skip,
545
+ context=context_token,
546
+ x_mask=x_mask,
547
+ context_mask=context_mask,
548
+ extras=self.extras
549
+ )
550
+ # if not self.training:
551
+ # print(
552
+ # f"out block {blk_idx}, min: {x.min().item()}, max: {x.max().item()}, std: {x.std().item()}"
553
+ # )
554
+
555
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
556
+
557
+ return x
558
+
559
+
560
+ class InputFusionAudioDiT(UDiT):
561
+ def __init__(
562
+ self,
563
+ img_size=224,
564
+ patch_size=16,
565
+ in_chans=3,
566
+ input_type='2d',
567
+ out_chans=None,
568
+ embed_dim=768,
569
+ depth=12,
570
+ num_heads=12,
571
+ mlp_ratio=4,
572
+ qkv_bias=False,
573
+ qk_scale=None,
574
+ qk_norm=None,
575
+ act_layer='gelu',
576
+ norm_layer='layernorm',
577
+ context_norm=False,
578
+ use_checkpoint=False,
579
+ time_fusion='token',
580
+ ada_sola_rank=None,
581
+ ada_sola_alpha=None,
582
+ cls_dim=None,
583
+ ta_context_dim=768,
584
+ context_dim=768,
585
+ context_fusion='concat',
586
+ context_max_length=128,
587
+ context_pe_method='sinu',
588
+ pe_method='abs',
589
+ rope_mode='none',
590
+ use_conv=True,
591
+ skip=True,
592
+ skip_norm=True
593
+ ):
594
+ super().__init__(
595
+ img_size,
596
+ patch_size,
597
+ in_chans,
598
+ input_type,
599
+ out_chans,
600
+ embed_dim,
601
+ depth,
602
+ num_heads,
603
+ mlp_ratio,
604
+ qkv_bias,
605
+ qk_scale,
606
+ qk_norm,
607
+ act_layer,
608
+ norm_layer,
609
+ context_norm,
610
+ use_checkpoint,
611
+ time_fusion,
612
+ ada_sola_rank,
613
+ ada_sola_alpha,
614
+ cls_dim,
615
+ context_dim,
616
+ context_fusion,
617
+ context_max_length,
618
+ context_pe_method,
619
+ pe_method,
620
+ rope_mode,
621
+ use_conv,
622
+ skip,
623
+ skip_norm,
624
+ )
625
+ self.input_proj = nn.Linear(in_chans + ta_context_dim, in_chans)
626
+ nn.init.xavier_uniform_(self.input_proj.weight)
627
+ nn.init.constant_(self.input_proj.bias, 0)
628
+
629
+ def forward(
630
+ self,
631
+ x,
632
+ timesteps,
633
+ time_aligned_context,
634
+ context,
635
+ x_mask=None,
636
+ context_mask=None,
637
+ cls_token=None,
638
+ controlnet_skips=None
639
+ ):
640
+ x = self.input_proj(
641
+ torch.cat([x.transpose(1, 2), time_aligned_context], dim=-1)
642
+ )
643
+ x = x.transpose(1, 2)
644
+ return super().forward(
645
+ x=x,
646
+ timesteps=timesteps,
647
+ context=context,
648
+ x_mask=x_mask,
649
+ context_mask=context_mask,
650
+ cls_token=cls_token,
651
+ controlnet_skips=controlnet_skips
652
+ )
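InputFusionAudioDiT above fuses the time-aligned context once, at the input: each latent frame is concatenated with its context vector along the channel axis and projected back to in_chans before the ordinary UDiT forward pass. A self-contained sketch of that projection (illustrative only, arbitrary sizes):

import torch
import torch.nn as nn

in_chans, ta_dim, T = 64, 128, 200
input_proj = nn.Linear(in_chans + ta_dim, in_chans)

x = torch.randn(2, in_chans, T)         # latent, channels-first
ta_ctx = torch.randn(2, T, ta_dim)      # one context vector per latent frame

fused = input_proj(torch.cat([x.transpose(1, 2), ta_ctx], dim=-1))
x = fused.transpose(1, 2)               # back to (B, C, T) for patch embedding
print(x.shape)                          # torch.Size([2, 64, 200])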
models/dit/mask_dit.py ADDED
@@ -0,0 +1,823 @@
1
+ import logging
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.checkpoint import checkpoint
6
+
7
+ from .modules import (
8
+ film_modulate,
9
+ unpatchify,
10
+ PatchEmbed,
11
+ PE_wrapper,
12
+ TimestepEmbedder,
13
+ FeedForward,
14
+ RMSNorm,
15
+ )
16
+ from .span_mask import compute_mask_indices
17
+ from .attention import Attention
18
+
19
+ logger = logging.Logger(__file__)
20
+
21
+
22
+ class AdaLN(nn.Module):
23
+ def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
24
+ super().__init__()
25
+ self.ada_mode = ada_mode
26
+ self.scale_shift_table = None
27
+ if ada_mode == 'ada':
28
+ # move nn.silu outside
29
+ self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
30
+ elif ada_mode == 'ada_single':
31
+ # adaln used in pixel-art alpha
32
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
33
+ elif ada_mode in ['ada_sola', 'ada_sola_bias']:
34
+ self.lora_a = nn.Linear(dim, r * 6, bias=False)
35
+ self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
36
+ self.scaling = alpha / r
37
+ if ada_mode == 'ada_sola_bias':
38
+ # take bias out for consistency
39
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
40
+ else:
41
+ raise NotImplementedError
42
+
43
+ def forward(self, time_token=None, time_ada=None):
44
+ if self.ada_mode == 'ada':
45
+ assert time_ada is None
46
+ B = time_token.shape[0]
47
+ time_ada = self.time_ada(time_token).reshape(B, 6, -1)
48
+ elif self.ada_mode == 'ada_single':
49
+ B = time_ada.shape[0]
50
+ time_ada = time_ada.reshape(B, 6, -1)
51
+ time_ada = self.scale_shift_table[None] + time_ada
52
+ elif self.ada_mode in ['ada_sola', 'ada_sola_bias']:
53
+ B = time_ada.shape[0]
54
+ time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
55
+ time_ada = time_ada + time_ada_lora
56
+ time_ada = time_ada.reshape(B, 6, -1)
57
+ if self.scale_shift_table is not None:
58
+ time_ada = self.scale_shift_table[None] + time_ada
59
+ else:
60
+ raise NotImplementedError
61
+ return time_ada
62
+
63
+
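A self-contained sketch (illustrative only) of how the six AdaLN outputs are consumed by the DiTBlock defined next: shift and scale modulate the normalized activations via film_modulate, and the gate scales the residual branch. The attention call is replaced by a simple stand-in here.

import torch
import torch.nn as nn

B, T, D = 2, 100, 64
time_token = torch.randn(B, D)
time_ada = nn.Linear(D, 6 * D)(nn.SiLU()(time_token)).reshape(B, 6, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = time_ada.chunk(6, dim=1)

x = torch.randn(B, T, D)
x_norm = nn.LayerNorm(D)(x)
x_mod = x_norm * (1 + scale_msa) + shift_msa      # film_modulate
branch = torch.tanh(x_mod)                        # stand-in for the self-attention output
x = x + (1 - gate_msa) * branch                   # gated residual, as in DiTBlock
print(x.shape)                                    # torch.Size([2, 100, 64])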
64
+ class DiTBlock(nn.Module):
65
+ """
66
+ A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
67
+ """
68
+ def __init__(
69
+ self,
70
+ dim,
71
+ context_dim=None,
72
+ num_heads=8,
73
+ mlp_ratio=4.,
74
+ qkv_bias=False,
75
+ qk_scale=None,
76
+ qk_norm=None,
77
+ act_layer='gelu',
78
+ norm_layer=nn.LayerNorm,
79
+ time_fusion='none',
80
+ ada_sola_rank=None,
81
+ ada_sola_alpha=None,
82
+ skip=False,
83
+ skip_norm=False,
84
+ rope_mode='none',
85
+ context_norm=False,
86
+ use_checkpoint=False
87
+ ):
88
+
89
+ super().__init__()
90
+ self.norm1 = norm_layer(dim)
91
+ self.attn = Attention(
92
+ dim=dim,
93
+ num_heads=num_heads,
94
+ qkv_bias=qkv_bias,
95
+ qk_scale=qk_scale,
96
+ qk_norm=qk_norm,
97
+ rope_mode=rope_mode
98
+ )
99
+
100
+ if context_dim is not None:
101
+ self.use_context = True
102
+ self.cross_attn = Attention(
103
+ dim=dim,
104
+ num_heads=num_heads,
105
+ context_dim=context_dim,
106
+ qkv_bias=qkv_bias,
107
+ qk_scale=qk_scale,
108
+ qk_norm=qk_norm,
109
+ rope_mode='none'
110
+ )
111
+ self.norm2 = norm_layer(dim)
112
+ if context_norm:
113
+ self.norm_context = norm_layer(context_dim)
114
+ else:
115
+ self.norm_context = nn.Identity()
116
+ else:
117
+ self.use_context = False
118
+
119
+ self.norm3 = norm_layer(dim)
120
+ self.mlp = FeedForward(
121
+ dim=dim, mult=mlp_ratio, activation_fn=act_layer, dropout=0
122
+ )
123
+
124
+ self.use_adanorm = True if time_fusion != 'token' else False
125
+ if self.use_adanorm:
126
+ self.adaln = AdaLN(
127
+ dim,
128
+ ada_mode=time_fusion,
129
+ r=ada_sola_rank,
130
+ alpha=ada_sola_alpha
131
+ )
132
+ if skip:
133
+ self.skip_norm = norm_layer(2 *
134
+ dim) if skip_norm else nn.Identity()
135
+ self.skip_linear = nn.Linear(2 * dim, dim)
136
+ else:
137
+ self.skip_linear = None
138
+
139
+ self.use_checkpoint = use_checkpoint
140
+
141
+ def forward(
142
+ self,
143
+ x,
144
+ time_token=None,
145
+ time_ada=None,
146
+ skip=None,
147
+ context=None,
148
+ x_mask=None,
149
+ context_mask=None,
150
+ extras=None
151
+ ):
152
+ if self.use_checkpoint:
153
+ return checkpoint(
154
+ self._forward,
155
+ x,
156
+ time_token,
157
+ time_ada,
158
+ skip,
159
+ context,
160
+ x_mask,
161
+ context_mask,
162
+ extras,
163
+ use_reentrant=False
164
+ )
165
+ else:
166
+ return self._forward(
167
+ x, time_token, time_ada, skip, context, x_mask, context_mask,
168
+ extras
169
+ )
170
+
171
+ def _forward(
172
+ self,
173
+ x,
174
+ time_token=None,
175
+ time_ada=None,
176
+ skip=None,
177
+ context=None,
178
+ x_mask=None,
179
+ context_mask=None,
180
+ extras=None
181
+ ):
182
+ B, T, C = x.shape
183
+ if self.skip_linear is not None:
184
+ assert skip is not None
185
+ cat = torch.cat([x, skip], dim=-1)
186
+ cat = self.skip_norm(cat)
187
+ x = self.skip_linear(cat)
188
+
189
+ if self.use_adanorm:
190
+ time_ada = self.adaln(time_token, time_ada)
191
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
192
+ gate_mlp) = time_ada.chunk(6, dim=1)
193
+
194
+ # self attention
195
+ if self.use_adanorm:
196
+ x_norm = film_modulate(
197
+ self.norm1(x), shift=shift_msa, scale=scale_msa
198
+ )
199
+ x = x + (1 - gate_msa) * self.attn(
200
+ x_norm, context=None, context_mask=x_mask, extras=extras
201
+ )
202
+ else:
203
+ x = x + self.attn(
204
+ self.norm1(x),
205
+ context=None,
206
+ context_mask=x_mask,
207
+ extras=extras
208
+ )
209
+
210
+ # cross attention
211
+ if self.use_context:
212
+ assert context is not None
213
+ x = x + self.cross_attn(
214
+ x=self.norm2(x),
215
+ context=self.norm_context(context),
216
+ context_mask=context_mask,
217
+ extras=extras
218
+ )
219
+
220
+ # mlp
221
+ if self.use_adanorm:
222
+ x_norm = film_modulate(
223
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
224
+ )
225
+ x = x + (1 - gate_mlp) * self.mlp(x_norm)
226
+ else:
227
+ x = x + self.mlp(self.norm3(x))
228
+
229
+ return x
230
+
231
+
232
+ class FinalBlock(nn.Module):
233
+ def __init__(
234
+ self,
235
+ embed_dim,
236
+ patch_size,
237
+ in_chans,
238
+ img_size,
239
+ input_type='2d',
240
+ norm_layer=nn.LayerNorm,
241
+ use_conv=True,
242
+ use_adanorm=True
243
+ ):
244
+ super().__init__()
245
+ self.in_chans = in_chans
246
+ self.img_size = img_size
247
+ self.input_type = input_type
248
+
249
+ self.norm = norm_layer(embed_dim)
250
+ if use_adanorm:
251
+ self.use_adanorm = True
252
+ else:
253
+ self.use_adanorm = False
254
+
255
+ if input_type == '2d':
256
+ self.patch_dim = patch_size**2 * in_chans
257
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
258
+ if use_conv:
259
+ self.final_layer = nn.Conv2d(
260
+ self.in_chans, self.in_chans, 3, padding=1
261
+ )
262
+ else:
263
+ self.final_layer = nn.Identity()
264
+
265
+ elif input_type == '1d':
266
+ self.patch_dim = patch_size * in_chans
267
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
268
+ if use_conv:
269
+ self.final_layer = nn.Conv1d(
270
+ self.in_chans, self.in_chans, 3, padding=1
271
+ )
272
+ else:
273
+ self.final_layer = nn.Identity()
274
+
275
+ def forward(self, x, time_ada=None, extras=0):
276
+ B, T, C = x.shape
277
+ x = x[:, extras:, :]
278
+ # only handle generation target
279
+ if self.use_adanorm:
280
+ shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
281
+ x = film_modulate(self.norm(x), shift, scale)
282
+ else:
283
+ x = self.norm(x)
284
+ x = self.linear(x)
285
+ x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
286
+ x = self.final_layer(x)
287
+ return x
288
+
289
+
290
+ class UDiT(nn.Module):
291
+ def __init__(
292
+ self,
293
+ img_size=224,
294
+ patch_size=16,
295
+ in_chans=3,
296
+ input_type='2d',
297
+ out_chans=None,
298
+ embed_dim=768,
299
+ depth=12,
300
+ num_heads=12,
301
+ mlp_ratio=4.,
302
+ qkv_bias=False,
303
+ qk_scale=None,
304
+ qk_norm=None,
305
+ act_layer='gelu',
306
+ norm_layer='layernorm',
307
+ context_norm=False,
308
+ use_checkpoint=False,
309
+ # time fusion ada or token
310
+ time_fusion='token',
311
+ ada_sola_rank=None,
312
+ ada_sola_alpha=None,
313
+ cls_dim=None,
314
+ # max length is only used for concat
315
+ context_dim=768,
316
+ context_fusion='concat',
317
+ context_max_length=128,
318
+ context_pe_method='sinu',
319
+ pe_method='abs',
320
+ rope_mode='none',
321
+ use_conv=True,
322
+ skip=True,
323
+ skip_norm=True
324
+ ):
325
+ super().__init__()
326
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
327
+
328
+ # input
329
+ self.in_chans = in_chans
330
+ self.input_type = input_type
331
+ if self.input_type == '2d':
332
+ num_patches = (img_size[0] //
333
+ patch_size) * (img_size[1] // patch_size)
334
+ elif self.input_type == '1d':
335
+ num_patches = img_size // patch_size
336
+ self.patch_embed = PatchEmbed(
337
+ patch_size=patch_size,
338
+ in_chans=in_chans,
339
+ embed_dim=embed_dim,
340
+ input_type=input_type
341
+ )
342
+ out_chans = in_chans if out_chans is None else out_chans
343
+ self.out_chans = out_chans
344
+
345
+ # position embedding
346
+ self.rope = rope_mode
347
+ self.x_pe = PE_wrapper(
348
+ dim=embed_dim, method=pe_method, length=num_patches
349
+ )
350
+
351
+ logger.info(f'x position embedding: {pe_method}')
352
+ logger.info(f'rope mode: {self.rope}')
353
+
354
+ # time embed
355
+ self.time_embed = TimestepEmbedder(embed_dim)
356
+ self.time_fusion = time_fusion
357
+ self.use_adanorm = False
358
+
359
+ # cls embed
360
+ if cls_dim is not None:
361
+ self.cls_embed = nn.Sequential(
362
+ nn.Linear(cls_dim, embed_dim, bias=True),
363
+ nn.SiLU(),
364
+ nn.Linear(embed_dim, embed_dim, bias=True),
365
+ )
366
+ else:
367
+ self.cls_embed = None
368
+
369
+ # time fusion
370
+ if time_fusion == 'token':
371
+ # put token at the beginning of sequence
372
+ self.extras = 2 if self.cls_embed else 1
373
+ self.time_pe = PE_wrapper(
374
+ dim=embed_dim, method='abs', length=self.extras
375
+ )
376
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
377
+ self.use_adanorm = True
378
+ # aviod repetitive silu for each adaln block
379
+ self.time_act = nn.SiLU()
380
+ self.extras = 0
381
+ self.time_ada_final = nn.Linear(
382
+ embed_dim, 2 * embed_dim, bias=True
383
+ )
384
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
385
+ # shared adaln
386
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
387
+ else:
388
+ self.time_ada = None
389
+ else:
390
+ raise NotImplementedError
391
+ logger.info(f'time fusion mode: {self.time_fusion}')
392
+
393
+ # context
394
+ # use a simple projection
395
+ self.use_context = False
396
+ self.context_cross = False
397
+ self.context_max_length = context_max_length
398
+ self.context_fusion = 'none'
399
+ if context_dim is not None:
400
+ self.use_context = True
401
+ self.context_embed = nn.Sequential(
402
+ nn.Linear(context_dim, embed_dim, bias=True),
403
+ nn.SiLU(),
404
+ nn.Linear(embed_dim, embed_dim, bias=True),
405
+ )
406
+ self.context_fusion = context_fusion
407
+ if context_fusion == 'concat' or context_fusion == 'joint':
408
+ self.extras += context_max_length
409
+ self.context_pe = PE_wrapper(
410
+ dim=embed_dim,
411
+ method=context_pe_method,
412
+ length=context_max_length
413
+ )
414
+ # no cross attention layers
415
+ context_dim = None
416
+ elif context_fusion == 'cross':
417
+ self.context_pe = PE_wrapper(
418
+ dim=embed_dim,
419
+ method=context_pe_method,
420
+ length=context_max_length
421
+ )
422
+ self.context_cross = True
423
+ context_dim = embed_dim
424
+ else:
425
+ raise NotImplementedError
426
+ logger.info(f'context fusion mode: {context_fusion}')
427
+ logger.info(f'context position embedding: {context_pe_method}')
428
+
429
+ self.use_skip = skip
430
+
431
+ # norm layers
432
+ if norm_layer == 'layernorm':
433
+ norm_layer = nn.LayerNorm
434
+ elif norm_layer == 'rmsnorm':
435
+ norm_layer = RMSNorm
436
+ else:
437
+ raise NotImplementedError
438
+
439
+ logger.info(f'use long skip connection: {skip}')
440
+ self.in_blocks = nn.ModuleList([
441
+ DiTBlock(
442
+ dim=embed_dim,
443
+ context_dim=context_dim,
444
+ num_heads=num_heads,
445
+ mlp_ratio=mlp_ratio,
446
+ qkv_bias=qkv_bias,
447
+ qk_scale=qk_scale,
448
+ qk_norm=qk_norm,
449
+ act_layer=act_layer,
450
+ norm_layer=norm_layer,
451
+ time_fusion=time_fusion,
452
+ ada_sola_rank=ada_sola_rank,
453
+ ada_sola_alpha=ada_sola_alpha,
454
+ skip=False,
455
+ skip_norm=False,
456
+ rope_mode=self.rope,
457
+ context_norm=context_norm,
458
+ use_checkpoint=use_checkpoint
459
+ ) for _ in range(depth // 2)
460
+ ])
461
+
462
+ self.mid_block = DiTBlock(
463
+ dim=embed_dim,
464
+ context_dim=context_dim,
465
+ num_heads=num_heads,
466
+ mlp_ratio=mlp_ratio,
467
+ qkv_bias=qkv_bias,
468
+ qk_scale=qk_scale,
469
+ qk_norm=qk_norm,
470
+ act_layer=act_layer,
471
+ norm_layer=norm_layer,
472
+ time_fusion=time_fusion,
473
+ ada_sola_rank=ada_sola_rank,
474
+ ada_sola_alpha=ada_sola_alpha,
475
+ skip=False,
476
+ skip_norm=False,
477
+ rope_mode=self.rope,
478
+ context_norm=context_norm,
479
+ use_checkpoint=use_checkpoint
480
+ )
481
+
482
+ self.out_blocks = nn.ModuleList([
483
+ DiTBlock(
484
+ dim=embed_dim,
485
+ context_dim=context_dim,
486
+ num_heads=num_heads,
487
+ mlp_ratio=mlp_ratio,
488
+ qkv_bias=qkv_bias,
489
+ qk_scale=qk_scale,
490
+ qk_norm=qk_norm,
491
+ act_layer=act_layer,
492
+ norm_layer=norm_layer,
493
+ time_fusion=time_fusion,
494
+ ada_sola_rank=ada_sola_rank,
495
+ ada_sola_alpha=ada_sola_alpha,
496
+ skip=skip,
497
+ skip_norm=skip_norm,
498
+ rope_mode=self.rope,
499
+ context_norm=context_norm,
500
+ use_checkpoint=use_checkpoint
501
+ ) for _ in range(depth // 2)
502
+ ])
503
+
504
+ # FinalLayer block
505
+ self.use_conv = use_conv
506
+ self.final_block = FinalBlock(
507
+ embed_dim=embed_dim,
508
+ patch_size=patch_size,
509
+ img_size=img_size,
510
+ in_chans=out_chans,
511
+ input_type=input_type,
512
+ norm_layer=norm_layer,
513
+ use_conv=use_conv,
514
+ use_adanorm=self.use_adanorm
515
+ )
516
+ self.initialize_weights()
517
+
518
+ def _init_ada(self):
519
+ if self.time_fusion == 'ada':
520
+ nn.init.constant_(self.time_ada_final.weight, 0)
521
+ nn.init.constant_(self.time_ada_final.bias, 0)
522
+ for block in self.in_blocks:
523
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
524
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
525
+ nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
526
+ nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
527
+ for block in self.out_blocks:
528
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
529
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
530
+ elif self.time_fusion == 'ada_single':
531
+ nn.init.constant_(self.time_ada.weight, 0)
532
+ nn.init.constant_(self.time_ada.bias, 0)
533
+ nn.init.constant_(self.time_ada_final.weight, 0)
534
+ nn.init.constant_(self.time_ada_final.bias, 0)
535
+ elif self.time_fusion in ['ada_sola', 'ada_sola_bias']:
536
+ nn.init.constant_(self.time_ada.weight, 0)
537
+ nn.init.constant_(self.time_ada.bias, 0)
538
+ nn.init.constant_(self.time_ada_final.weight, 0)
539
+ nn.init.constant_(self.time_ada_final.bias, 0)
540
+ for block in self.in_blocks:
541
+ nn.init.kaiming_uniform_(
542
+ block.adaln.lora_a.weight, a=math.sqrt(5)
543
+ )
544
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
545
+ nn.init.kaiming_uniform_(
546
+ self.mid_block.adaln.lora_a.weight, a=math.sqrt(5)
547
+ )
548
+ nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
549
+ for block in self.out_blocks:
550
+ nn.init.kaiming_uniform_(
551
+ block.adaln.lora_a.weight, a=math.sqrt(5)
552
+ )
553
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
554
+
555
+ def initialize_weights(self):
556
+ # Basic init for all layers
557
+ def _basic_init(module):
558
+ if isinstance(module, nn.Linear):
559
+ nn.init.xavier_uniform_(module.weight)
560
+ if module.bias is not None:
561
+ nn.init.constant_(module.bias, 0)
562
+
563
+ self.apply(_basic_init)
564
+
565
+ # init patch Conv like Linear
566
+ w = self.patch_embed.proj.weight.data
567
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
568
+ nn.init.constant_(self.patch_embed.proj.bias, 0)
569
+
570
+ # Zero-out AdaLN
571
+ if self.use_adanorm:
572
+ self._init_ada()
573
+
574
+ # Zero-out Cross Attention
575
+ if self.context_cross:
576
+ for block in self.in_blocks:
577
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
578
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
579
+ nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
580
+ nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
581
+ for block in self.out_blocks:
582
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
583
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
584
+
585
+ # Zero-out cls embedding
586
+ if self.cls_embed:
587
+ if self.use_adanorm:
588
+ nn.init.constant_(self.cls_embed[-1].weight, 0)
589
+ nn.init.constant_(self.cls_embed[-1].bias, 0)
590
+
591
+ # Zero-out Output
592
+ # might not zero-out this when using v-prediction
593
+ # it could be good when using noise-prediction
594
+ # nn.init.constant_(self.final_block.linear.weight, 0)
595
+ # nn.init.constant_(self.final_block.linear.bias, 0)
596
+ # if self.use_conv:
597
+ # nn.init.constant_(self.final_block.final_layer.weight.data, 0)
598
+ # nn.init.constant_(self.final_block.final_layer.bias, 0)
599
+
600
+ # init out Conv
601
+ if self.use_conv:
602
+ nn.init.xavier_uniform_(self.final_block.final_layer.weight)
603
+ nn.init.constant_(self.final_block.final_layer.bias, 0)
604
+
605
+ def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
606
+ assert context.shape[-2] == self.context_max_length
607
+ # Check if either x_mask or context_mask is provided
608
+ B = x.shape[0]
609
+ # Create default masks if they are not provided
610
+ if x_mask is None:
611
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
612
+ if context_mask is None:
613
+ context_mask = torch.ones(
614
+ B, context.shape[-2], device=context.device
615
+ ).bool()
616
+ # Concatenate the masks along the second dimension (dim=1)
617
+ x_mask = torch.cat([context_mask, x_mask], dim=1)
618
+ # Concatenate context and x along the second dimension (dim=1)
619
+ x = torch.cat((context, x), dim=1)
620
+ return x, x_mask
621
+
622
+ def forward(
623
+ self,
624
+ x,
625
+ timesteps,
626
+ context,
627
+ x_mask=None,
628
+ context_mask=None,
629
+ cls_token=None,
630
+ controlnet_skips=None,
631
+ ):
632
+ # make it compatible with int time step during inference
633
+ if timesteps.dim() == 0:
634
+ timesteps = timesteps.expand(x.shape[0]
635
+ ).to(x.device, dtype=torch.long)
636
+
637
+ x = self.patch_embed(x)
638
+ x = self.x_pe(x)
639
+
640
+ B, L, D = x.shape
641
+
642
+ if self.use_context:
643
+ context_token = self.context_embed(context)
644
+ context_token = self.context_pe(context_token)
645
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
646
+ x, x_mask = self._concat_x_context(
647
+ x=x,
648
+ context=context_token,
649
+ x_mask=x_mask,
650
+ context_mask=context_mask
651
+ )
652
+ context_token, context_mask = None, None
653
+ else:
654
+ context_token, context_mask = None, None
655
+
656
+ time_token = self.time_embed(timesteps)
657
+ if self.cls_embed:
658
+ cls_token = self.cls_embed(cls_token)
659
+ time_ada = None
660
+ time_ada_final = None
661
+ if self.use_adanorm:
662
+ if self.cls_embed:
663
+ time_token = time_token + cls_token
664
+ time_token = self.time_act(time_token)
665
+ time_ada_final = self.time_ada_final(time_token)
666
+ if self.time_ada is not None:
667
+ time_ada = self.time_ada(time_token)
668
+ else:
669
+ time_token = time_token.unsqueeze(dim=1)
670
+ if self.cls_embed:
671
+ cls_token = cls_token.unsqueeze(dim=1)
672
+ time_token = torch.cat([time_token, cls_token], dim=1)
673
+ time_token = self.time_pe(time_token)
674
+ x = torch.cat((time_token, x), dim=1)
675
+ if x_mask is not None:
676
+ x_mask = torch.cat([
677
+ torch.ones(B, time_token.shape[1],
678
+ device=x_mask.device).bool(), x_mask
679
+ ],
680
+ dim=1)
681
+ time_token = None
682
+
683
+ skips = []
684
+ for blk in self.in_blocks:
685
+ x = blk(
686
+ x=x,
687
+ time_token=time_token,
688
+ time_ada=time_ada,
689
+ skip=None,
690
+ context=context_token,
691
+ x_mask=x_mask,
692
+ context_mask=context_mask,
693
+ extras=self.extras
694
+ )
695
+ if self.use_skip:
696
+ skips.append(x)
697
+
698
+ x = self.mid_block(
699
+ x=x,
700
+ time_token=time_token,
701
+ time_ada=time_ada,
702
+ skip=None,
703
+ context=context_token,
704
+ x_mask=x_mask,
705
+ context_mask=context_mask,
706
+ extras=self.extras
707
+ )
708
+ for blk in self.out_blocks:
709
+ if self.use_skip:
710
+ skip = skips.pop()
711
+ if controlnet_skips:
712
+ # add to skip like u-net controlnet
713
+ skip = skip + controlnet_skips.pop()
714
+ else:
715
+ skip = None
716
+ if controlnet_skips:
717
+ # directly add to x
718
+ x = x + controlnet_skips.pop()
719
+
720
+ x = blk(
721
+ x=x,
722
+ time_token=time_token,
723
+ time_ada=time_ada,
724
+ skip=skip,
725
+ context=context_token,
726
+ x_mask=x_mask,
727
+ context_mask=context_mask,
728
+ extras=self.extras
729
+ )
730
+
731
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
732
+
733
+ return x
734
+
735
+
736
+ class MaskDiT(nn.Module):
737
+ def __init__(
738
+ self,
739
+ model: UDiT,
740
+ mae=False,
741
+ mae_prob=0.5,
742
+ mask_ratio=[0.25, 1.0],
743
+ mask_span=10,
744
+ ):
745
+ super().__init__()
746
+ self.model = model
747
+ self.mae = mae
748
+ if self.mae:
749
+ out_channel = model.out_chans
750
+ self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
751
+ self.mae_prob = mae_prob
752
+ self.mask_ratio = mask_ratio
753
+ self.mask_span = mask_span
754
+
755
+ def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
756
+ B, D, L = gt.shape
757
+ if mae_mask_infer is None:
758
+ # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
759
+ mask_ratios = mask_ratios.cpu().numpy()
760
+ mask = compute_mask_indices(
761
+ shape=[B, L],
762
+ padding_mask=None,
763
+ mask_prob=mask_ratios,
764
+ mask_length=self.mask_span,
765
+ mask_type="static",
766
+ mask_other=0.0,
767
+ min_masks=1,
768
+ no_overlap=False,
769
+ min_space=0,
770
+ )
771
+ mask = mask.unsqueeze(1).expand_as(gt)
772
+ else:
773
+ mask = mae_mask_infer
774
+ mask = mask.expand_as(gt)
775
+ gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
776
+ return gt, mask.type_as(gt)
777
+
778
+ def forward(
779
+ self,
780
+ x,
781
+ timesteps,
782
+ context,
783
+ x_mask=None,
784
+ context_mask=None,
785
+ cls_token=None,
786
+ gt=None,
787
+ mae_mask_infer=None,
788
+ forward_model=True
789
+ ):
790
+ # todo: handle controlnet inside
791
+ mae_mask = torch.ones_like(x)
792
+ if self.mae:
793
+ if gt is not None:
794
+ B, D, L = gt.shape
795
+ mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio
796
+ ).to(gt.device)
797
+ gt, mae_mask = self.random_masking(
798
+ gt, mask_ratios, mae_mask_infer
799
+ )
800
+ # apply mae only to the selected batches
801
+ if mae_mask_infer is None:
802
+ # determine mae batch
803
+ mae_batch = torch.rand(B) < self.mae_prob
804
+ gt[~mae_batch] = self.mask_embed.view(
805
+ 1, D, 1
806
+ ).expand_as(gt)[~mae_batch]
807
+ mae_mask[~mae_batch] = 1.0
808
+ else:
809
+ B, D, L = x.shape
810
+ gt = self.mask_embed.view(1, D, 1).expand_as(x)
811
+ x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)
812
+
813
+ if forward_model:
814
+ x = self.model(
815
+ x=x,
816
+ timesteps=timesteps,
817
+ context=context,
818
+ x_mask=x_mask,
819
+ context_mask=context_mask,
820
+ cls_token=cls_token
821
+ )
822
+ # logger.info(mae_mask[:, 0, :].sum(dim=-1))
823
+ return x, mae_mask
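A self-contained sketch (illustrative only) of the MAE-style trick used by MaskDiT above. It uses simple i.i.d. frame masking instead of the span masking from span_mask.compute_mask_indices, but shows the same replacement of masked ground-truth frames with a learned mask embedding.

import torch

B, D, L = 2, 64, 100
gt = torch.randn(B, D, L)                        # ground-truth latent
mask_embed = torch.zeros(D)                      # learned parameter in MaskDiT

mask_ratio = 0.5
mask = torch.rand(B, L) < mask_ratio             # True = masked frame
mask = mask.unsqueeze(1).expand_as(gt)           # broadcast over channels
gt = torch.where(mask, mask_embed.view(1, D, 1).expand_as(gt), gt)
print(mask.float().mean())                       # roughly 0.5 of the frames are masked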
models/dit/modules.py ADDED
@@ -0,0 +1,445 @@
1
+ import warnings
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint
6
+ from torch.cuda.amp import autocast
7
+ import math
8
+ import einops
9
+ from einops import rearrange, repeat
10
+ from inspect import isfunction
11
+
12
+
13
+ def trunc_normal_(tensor, mean, std, a, b):
14
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
15
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
16
+ def norm_cdf(x):
17
+ # Computes standard normal cumulative distribution function
18
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
19
+
20
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
21
+ warnings.warn(
22
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
23
+ "The distribution of values may be incorrect.",
24
+ stacklevel=2
25
+ )
26
+
27
+ with torch.no_grad():
28
+ # Values are generated by using a truncated uniform distribution and
29
+ # then using the inverse CDF for the normal distribution.
30
+ # Get upper and lower cdf values
31
+ l = norm_cdf((a - mean) / std)
32
+ u = norm_cdf((b - mean) / std)
33
+
34
+ # Uniformly fill tensor with values from [l, u], then translate to
35
+ # [2l-1, 2u-1].
36
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
37
+
38
+ # Use inverse cdf transform for normal distribution to get truncated
39
+ # standard normal
40
+ tensor.erfinv_()
41
+
42
+ # Transform to proper mean, std
43
+ tensor.mul_(std * math.sqrt(2.))
44
+ tensor.add_(mean)
45
+
46
+ # Clamp to ensure it's in the proper range
47
+ tensor.clamp_(min=a, max=b)
48
+ return tensor
49
+
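A quick illustrative use of trunc_normal_ above, assuming the repository root is on PYTHONPATH; every sample is clamped into [a, b].

import torch
from models.dit.modules import trunc_normal_

w = torch.empty(1000)
trunc_normal_(w, mean=0.0, std=0.02, a=-0.04, b=0.04)
print(w.min().item() >= -0.04, w.max().item() <= 0.04)   # True True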
50
+
51
+ # disable in checkpoint mode
52
+ # @torch.jit.script
53
+ def film_modulate(x, shift, scale):
54
+ return x * (1 + scale) + shift
55
+
56
+
57
+ def timestep_embedding(timesteps, dim, max_period=10000):
58
+ """
59
+ Create sinusoidal timestep embeddings.
60
+
61
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
62
+ These may be fractional.
63
+ :param dim: the dimension of the output.
64
+ :param max_period: controls the minimum frequency of the embeddings.
65
+ :return: an [N x dim] Tensor of positional embeddings.
66
+ """
67
+ half = dim // 2
68
+ freqs = torch.exp(
69
+ -math.log(max_period) *
70
+ torch.arange(start=0, end=half, dtype=torch.float32) / half
71
+ ).to(device=timesteps.device)
72
+ args = timesteps[:, None].float() * freqs[None]
73
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
74
+ if dim % 2:
75
+ embedding = torch.cat([embedding,
76
+ torch.zeros_like(embedding[:, :1])],
77
+ dim=-1)
78
+ return embedding
79
+
80
+
81
+ class TimestepEmbedder(nn.Module):
82
+ """
83
+ Embeds scalar timesteps into vector representations.
84
+ """
85
+ def __init__(
86
+ self, hidden_size, frequency_embedding_size=256, out_size=None
87
+ ):
88
+ super().__init__()
89
+ if out_size is None:
90
+ out_size = hidden_size
91
+ self.mlp = nn.Sequential(
92
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
93
+ nn.SiLU(),
94
+ nn.Linear(hidden_size, out_size, bias=True),
95
+ )
96
+ self.frequency_embedding_size = frequency_embedding_size
97
+
98
+ def forward(self, t):
99
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(
100
+ self.mlp[0].weight.dtype
101
+ )
102
+ t_emb = self.mlp(t_freq)
103
+ return t_emb
104
+
105
+
106
+ def patchify(imgs, patch_size, input_type='2d'):
107
+ if input_type == '2d':
108
+ x = einops.rearrange(
109
+ imgs,
110
+ 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)',
111
+ p1=patch_size,
112
+ p2=patch_size
113
+ )
114
+ elif input_type == '1d':
115
+ x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size)
116
+ return x
117
+
118
+
119
+ def unpatchify(x, channels=3, input_type='2d', img_size=None):
120
+ if input_type == '2d':
121
+ patch_size = int((x.shape[2] // channels)**0.5)
122
+ # h = w = int(x.shape[1] ** .5)
123
+ h, w = img_size[0] // patch_size, img_size[1] // patch_size
124
+ assert h * w == x.shape[1] and patch_size**2 * channels == x.shape[2]
125
+ x = einops.rearrange(
126
+ x,
127
+ 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)',
128
+ h=h,
129
+ p1=patch_size,
130
+ p2=patch_size
131
+ )
132
+ elif input_type == '1d':
133
+ patch_size = int((x.shape[2] // channels))
134
+ h = x.shape[1]
135
+ assert patch_size * channels == x.shape[2]
136
+ x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size)
137
+ return x
138
+
139
+
140
+ class PatchEmbed(nn.Module):
141
+ """
142
+ Image to Patch Embedding
143
+ """
144
+ def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'):
145
+ super().__init__()
146
+ self.patch_size = patch_size
147
+ self.input_type = input_type
148
+ if input_type == '2d':
149
+ self.proj = nn.Conv2d(
150
+ in_chans,
151
+ embed_dim,
152
+ kernel_size=patch_size,
153
+ stride=patch_size,
154
+ bias=True
155
+ )
156
+ elif input_type == '1d':
157
+ self.proj = nn.Conv1d(
158
+ in_chans,
159
+ embed_dim,
160
+ kernel_size=patch_size,
161
+ stride=patch_size,
162
+ bias=True
163
+ )
164
+
165
+ def forward(self, x):
166
+ if self.input_type == '2d':
167
+ B, C, H, W = x.shape
168
+ assert H % self.patch_size == 0 and W % self.patch_size == 0
169
+ elif self.input_type == '1d':
170
+ B, C, H = x.shape
171
+ assert H % self.patch_size == 0
172
+
173
+ x = self.proj(x).flatten(2).transpose(1, 2)
174
+ return x
175
+
176
+
177
+ class PositionalConvEmbedding(nn.Module):
178
+ """
179
+ Convolutional positional embedding used in F5-TTS.
180
+ """
181
+ def __init__(self, dim=768, kernel_size=31, groups=16):
182
+ super().__init__()
183
+ assert kernel_size % 2 != 0
184
+ self.conv1d = nn.Sequential(
185
+ nn.Conv1d(
186
+ dim, dim, kernel_size, groups=groups, padding=kernel_size // 2
187
+ ),
188
+ nn.Mish(),
189
+ nn.Conv1d(
190
+ dim, dim, kernel_size, groups=groups, padding=kernel_size // 2
191
+ ),
192
+ nn.Mish(),
193
+ )
194
+
195
+ def forward(self, x):
196
+ # B T C
197
+ x = self.conv1d(x.transpose(1, 2))
198
+ x = x.transpose(1, 2)
199
+ return x
200
+
201
+
202
+ class SinusoidalPositionalEncoding(nn.Module):
203
+ def __init__(self, dim, length):
204
+ super(SinusoidalPositionalEncoding, self).__init__()
205
+ self.length = length
206
+ self.dim = dim
207
+ self.register_buffer(
208
+ 'pe', self._generate_positional_encoding(length, dim)
209
+ )
210
+
211
+ def _generate_positional_encoding(self, length, dim):
212
+ pe = torch.zeros(length, dim)
213
+ position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
214
+ div_term = torch.exp(
215
+ torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)
216
+ )
217
+
218
+ pe[:, 0::2] = torch.sin(position * div_term)
219
+ pe[:, 1::2] = torch.cos(position * div_term)
220
+
221
+ pe = pe.unsqueeze(0)
222
+ return pe
223
+
224
+ def forward(self, x):
225
+ x = x + self.pe[:, :x.size(1)]
226
+ return x
227
+
228
+
229
+ class PE_wrapper(nn.Module):
230
+ def __init__(self, dim=768, method='abs', length=None, **kwargs):
231
+ super().__init__()
232
+ self.method = method
233
+ if method == 'abs':
234
+ # init absolute pe like UViT
235
+ self.length = length
236
+ self.abs_pe = nn.Parameter(torch.zeros(1, length, dim))
237
+ trunc_normal_(self.abs_pe, mean=0.0, std=.02, a=-.04, b=.04)
238
+ elif method == 'conv':
239
+ self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs)
240
+ elif method == 'sinu':
241
+ self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length)
242
+ elif method == 'none':
243
+ # skip pe
244
+ self.id = nn.Identity()
245
+ else:
246
+ raise NotImplementedError
247
+
248
+ def forward(self, x):
249
+ if self.method == 'abs':
250
+ _, L, _ = x.shape
251
+ assert L <= self.length
252
+ x = x + self.abs_pe[:, :L, :]
253
+ elif self.method == 'conv':
254
+ x = x + self.conv_pe(x)
255
+ elif self.method == 'sinu':
256
+ x = self.sinu_pe(x)
257
+ elif self.method == 'none':
258
+ x = self.id(x)
259
+ else:
260
+ raise NotImplementedError
261
+ return x
262
+
263
+
264
+ class RMSNorm(torch.nn.Module):
265
+ def __init__(self, dim: int, eps: float = 1e-6):
266
+ """
267
+ Initialize the RMSNorm normalization layer.
268
+
269
+ Args:
270
+ dim (int): The dimension of the input tensor.
271
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
272
+
273
+ Attributes:
274
+ eps (float): A small value added to the denominator for numerical stability.
275
+ weight (nn.Parameter): Learnable scaling parameter.
276
+
277
+ """
278
+ super().__init__()
279
+ self.eps = eps
280
+ self.weight = nn.Parameter(torch.ones(dim))
281
+
282
+ def _norm(self, x):
283
+ """
284
+ Apply the RMSNorm normalization to the input tensor.
285
+
286
+ Args:
287
+ x (torch.Tensor): The input tensor.
288
+
289
+ Returns:
290
+ torch.Tensor: The normalized tensor.
291
+
292
+ """
293
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
294
+
295
+ def forward(self, x):
296
+ """
297
+ Forward pass through the RMSNorm layer.
298
+
299
+ Args:
300
+ x (torch.Tensor): The input tensor.
301
+
302
+ Returns:
303
+ torch.Tensor: The output tensor after applying RMSNorm.
304
+
305
+ """
306
+ output = self._norm(x.float()).type_as(x)
307
+ return output * self.weight
308
+
309
+
310
+ class GELU(nn.Module):
311
+ def __init__(
312
+ self,
313
+ dim_in: int,
314
+ dim_out: int,
315
+ approximate: str = "none",
316
+ bias: bool = True
317
+ ):
318
+ super().__init__()
319
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
320
+ self.approximate = approximate
321
+
322
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
323
+ if gate.device.type != "mps":
324
+ return F.gelu(gate, approximate=self.approximate)
325
+ # mps: gelu is not implemented for float16
326
+ return F.gelu(
327
+ gate.to(dtype=torch.float32), approximate=self.approximate
328
+ ).to(dtype=gate.dtype)
329
+
330
+ def forward(self, hidden_states):
331
+ hidden_states = self.proj(hidden_states)
332
+ hidden_states = self.gelu(hidden_states)
333
+ return hidden_states
334
+
335
+
336
+ class GEGLU(nn.Module):
337
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
338
+ super().__init__()
339
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
340
+
341
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
342
+ if gate.device.type != "mps":
343
+ return F.gelu(gate)
344
+ # mps: gelu is not implemented for float16
345
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
346
+
347
+ def forward(self, hidden_states):
348
+ hidden_states = self.proj(hidden_states)
349
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
350
+ return hidden_states * self.gelu(gate)
351
+
352
+
353
+ class ApproximateGELU(nn.Module):
354
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
355
+ super().__init__()
356
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
357
+
358
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
359
+ x = self.proj(x)
360
+ return x * torch.sigmoid(1.702 * x)
361
+
362
+
363
+ # disable in checkpoint mode
364
+ # @torch.jit.script
365
+ def snake_beta(x, alpha, beta):
366
+ return x + beta * torch.sin(x * alpha).pow(2)
367
+
368
+
369
+ class Snake(nn.Module):
370
+ def __init__(self, dim_in, dim_out, bias, alpha_trainable=True):
371
+ super().__init__()
372
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
373
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
374
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
375
+ self.alpha.requires_grad = alpha_trainable
376
+ self.beta.requires_grad = alpha_trainable
377
+
378
+ def forward(self, x):
379
+ x = self.proj(x)
380
+ x = snake_beta(x, self.alpha, self.beta)
381
+ return x
382
+
383
+
384
+ class GESnake(nn.Module):
385
+ def __init__(self, dim_in, dim_out, bias, alpha_trainable=True):
386
+ super().__init__()
387
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
388
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
389
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
390
+ self.alpha.requires_grad = alpha_trainable
391
+ self.beta.requires_grad = alpha_trainable
392
+
393
+ def forward(self, x):
394
+ x = self.proj(x)
395
+ x, gate = x.chunk(2, dim=-1)
396
+ return x * snake_beta(gate, self.alpha, self.beta)
397
+
398
+
399
+ class FeedForward(nn.Module):
400
+ def __init__(
401
+ self,
402
+ dim,
403
+ dim_out=None,
404
+ mult=4,
405
+ dropout=0.0,
406
+ activation_fn="geglu",
407
+ final_dropout=False,
408
+ inner_dim=None,
409
+ bias=True,
410
+ ):
411
+ super().__init__()
412
+ if inner_dim is None:
413
+ inner_dim = int(dim * mult)
414
+ dim_out = dim_out if dim_out is not None else dim
415
+
416
+ if activation_fn == "gelu":
417
+ act_fn = GELU(dim, inner_dim, bias=bias)
418
+ elif activation_fn == "gelu-approximate":
419
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
420
+ elif activation_fn == "geglu":
421
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
422
+ elif activation_fn == "geglu-approximate":
423
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
424
+ elif activation_fn == "snake":
425
+ act_fn = Snake(dim, inner_dim, bias=bias)
426
+ elif activation_fn == "gesnake":
427
+ act_fn = GESnake(dim, inner_dim, bias=bias)
428
+ else:
429
+ raise NotImplementedError
430
+
431
+ self.net = nn.ModuleList([])
432
+ # project in
433
+ self.net.append(act_fn)
434
+ # project dropout
435
+ self.net.append(nn.Dropout(dropout))
436
+ # project out
437
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
438
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
439
+ if final_dropout:
440
+ self.net.append(nn.Dropout(dropout))
441
+
442
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
443
+ for module in self.net:
444
+ hidden_states = module(hidden_states)
445
+ return hidden_states
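A minimal usage sketch of the helpers defined in this file, assuming the repository root is on `PYTHONPATH` so that `models.dit.modules` is importable:

```python
import torch
from models.dit.modules import patchify, unpatchify, TimestepEmbedder

# Round-trip a 1-D latent through patchify/unpatchify.
latent = torch.randn(2, 8, 100)                              # [B, C, T], T % patch_size == 0
patches = patchify(latent, patch_size=4, input_type='1d')    # [2, 25, 32]
recon = unpatchify(patches, channels=8, input_type='1d')     # [2, 8, 100]
assert torch.allclose(latent, recon)

# Embed a batch of (possibly fractional) timesteps.
t_embedder = TimestepEmbedder(hidden_size=256)
t_emb = t_embedder(torch.tensor([0.0, 500.0]))               # [2, 256]
```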
models/dit/rotary.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+ "this rope is faster than llama rope with jit script"
3
+
4
+
5
+ def rotate_half(x):
6
+ x1, x2 = x.chunk(2, dim=-1)
7
+ return torch.cat((-x2, x1), dim=-1)
8
+
9
+
10
+ # disable in checkpoint mode
11
+ # @torch.jit.script
12
+ def apply_rotary_pos_emb(x, cos, sin):
13
+ # NOTE: This could probably be moved to Triton
14
+ # Handle a possible sequence length mismatch in between q and k
15
+ cos = cos[:, :, :x.shape[-2], :]
16
+ sin = sin[:, :, :x.shape[-2], :]
17
+ return (x*cos) + (rotate_half(x) * sin)
18
+
19
+
20
+ class RotaryEmbedding(torch.nn.Module):
21
+ """
22
+ The rotary position embeddings from RoFormer_ (Su et. al).
23
+ A crucial insight from the method is that the query and keys are
24
+ transformed by rotation matrices which depend on the relative positions.
25
+
26
+ Other implementations are available in the Rotary Transformer repo_ and in
27
+ GPT-NeoX_, GPT-NeoX was an inspiration
28
+
29
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
30
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
31
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
32
+
33
+
34
+ .. warning: Please note that this embedding is not registered on purpose, as it is transformative
35
+ (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
36
+ """
37
+ def __init__(self, dim: int):
38
+ super().__init__()
39
+ # Generate and save the inverse frequency buffer (non trainable)
40
+ inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
41
+ self.register_buffer("inv_freq", inv_freq)
42
+ self._seq_len_cached = None
43
+ self._cos_cached = None
44
+ self._sin_cached = None
45
+
46
+ def _update_cos_sin_tables(self, x, seq_dimension=-2):
47
+ # expect input: B, H, L, D
48
+ seq_len = x.shape[seq_dimension]
49
+
50
+ # Reset the tables if the sequence length has changed,
51
+ # or if we're on a new device (possibly due to tracing for instance)
52
+ # also make sure dtype wont change
53
+ if (
54
+ seq_len != self._seq_len_cached or
55
+ self._cos_cached.device != x.device or
56
+ self._cos_cached.dtype != x.dtype
57
+ ):
58
+ self._seq_len_cached = seq_len
59
+ t = torch.arange(
60
+ x.shape[seq_dimension], device=x.device, dtype=torch.float32
61
+ )
62
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
63
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
64
+
65
+ self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
66
+ self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
67
+
68
+ return self._cos_cached, self._sin_cached
69
+
70
+ def forward(self, q, k):
71
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
72
+ q.float(), seq_dimension=-2
73
+ )
74
+ if k is not None:
75
+ return (
76
+ apply_rotary_pos_emb(
77
+ q.float(), self._cos_cached, self._sin_cached
78
+ ).type_as(q),
79
+ apply_rotary_pos_emb(
80
+ k.float(), self._cos_cached, self._sin_cached
81
+ ).type_as(k),
82
+ )
83
+ else:
84
+ return (
85
+ apply_rotary_pos_emb(
86
+ q.float(), self._cos_cached, self._sin_cached
87
+ ).type_as(q), None
88
+ )
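A short usage sketch for the rotary embedding above (again assuming `models.dit.rotary` is importable); queries and keys are expected as `[B, H, L, D_head]`:

```python
import torch
from models.dit.rotary import RotaryEmbedding

rope = RotaryEmbedding(dim=64)           # D_head
q = torch.randn(2, 8, 100, 64)
k = torch.randn(2, 8, 100, 64)
q_rot, k_rot = rope(q, k)                # same shapes, rotated by position
print(q_rot.shape, k_rot.shape)
```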
models/dit/span_mask.py ADDED
@@ -0,0 +1,149 @@
1
+ import numpy as np
2
+ import torch
3
+ from typing import Optional, Tuple
4
+
5
+
6
+ def compute_mask_indices(
7
+ shape: Tuple[int, int],
8
+ padding_mask: Optional[torch.Tensor],
9
+ mask_prob: float,
10
+ mask_length: int,
11
+ mask_type: str = "static",
12
+ mask_other: float = 0.0,
13
+ min_masks: int = 0,
14
+ no_overlap: bool = False,
15
+ min_space: int = 0,
16
+ ) -> np.ndarray:
17
+ """
18
+ Computes random mask spans for a given shape
19
+
20
+ Args:
21
+ shape: the shape for which to compute masks.
22
+ should be of size 2 where first element is batch size and 2nd is timesteps
23
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
24
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
25
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
26
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
27
+ mask_type: how to compute mask lengths
28
+ static = fixed size
29
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
30
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
31
+ poisson = sample from poisson distribution with lambda = mask length
32
+ min_masks: minimum number of masked spans
33
+ no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
34
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
35
+ """
36
+
37
+ bsz, all_sz = shape
38
+ mask = np.full((bsz, all_sz), False)
39
+
40
+ # Convert mask_prob to a NumPy array
41
+ mask_prob = np.array(mask_prob)
42
+
43
+ # Calculate all_num_mask for each element in the batch
44
+ all_num_mask = np.floor(
45
+ mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)
46
+ ).astype(int)
47
+
48
+ # Apply the max operation with min_masks for each element
49
+ all_num_mask = np.maximum(min_masks, all_num_mask)
50
+
51
+ mask_idcs = []
52
+ for i in range(bsz):
53
+ if padding_mask is not None:
54
+ sz = all_sz - padding_mask[i].long().sum().item()
55
+ num_mask = int(
56
+ # add a random number for probabilistic rounding
57
+ mask_prob * sz / float(mask_length) + np.random.rand()
58
+ )
59
+ num_mask = max(min_masks, num_mask)
60
+ else:
61
+ sz = all_sz
62
+ num_mask = all_num_mask[i]
63
+
64
+ if mask_type == "static":
65
+ lengths = np.full(num_mask, mask_length)
66
+ elif mask_type == "uniform":
67
+ lengths = np.random.randint(
68
+ mask_other, mask_length*2 + 1, size=num_mask
69
+ )
70
+ elif mask_type == "normal":
71
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
72
+ lengths = [max(1, int(round(x))) for x in lengths]
73
+ elif mask_type == "poisson":
74
+ lengths = np.random.poisson(mask_length, size=num_mask)
75
+ lengths = [int(round(x)) for x in lengths]
76
+ else:
77
+ raise Exception("unknown mask selection " + mask_type)
78
+
79
+ if sum(lengths) == 0:
80
+ lengths[0] = min(mask_length, sz - 1)
81
+
82
+ if no_overlap:
83
+ mask_idc = []
84
+
85
+ def arrange(s, e, length, keep_length):
86
+ span_start = np.random.randint(s, e - length)
87
+ mask_idc.extend(span_start + i for i in range(length))
88
+
89
+ new_parts = []
90
+ if span_start - s - min_space >= keep_length:
91
+ new_parts.append((s, span_start - min_space + 1))
92
+ if e - span_start - keep_length - min_space > keep_length:
93
+ new_parts.append((span_start + length + min_space, e))
94
+ return new_parts
95
+
96
+ parts = [(0, sz)]
97
+ min_length = min(lengths)
98
+ for length in sorted(lengths, reverse=True):
99
+ lens = np.fromiter(
100
+ (
101
+ e - s if e - s >= length + min_space else 0
102
+ for s, e in parts
103
+ ),
104
+ int,  # np.int was removed in NumPy 1.24; the builtin int is equivalent here
105
+ )
106
+ l_sum = np.sum(lens)
107
+ if l_sum == 0:
108
+ break
109
+ probs = lens / np.sum(lens)
110
+ c = np.random.choice(len(parts), p=probs)
111
+ s, e = parts.pop(c)
112
+ parts.extend(arrange(s, e, length, min_length))
113
+ mask_idc = np.asarray(mask_idc)
114
+ else:
115
+ min_len = min(lengths)
116
+ if sz - min_len <= num_mask:
117
+ min_len = sz - num_mask - 1
118
+
119
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
120
+
121
+ mask_idc = np.asarray([
122
+ mask_idc[j] + offset for j in range(len(mask_idc))
123
+ for offset in range(lengths[j])
124
+ ])
125
+
126
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
127
+ # min_len = min([len(m) for m in mask_idcs])
128
+ for i, mask_idc in enumerate(mask_idcs):
129
+ # if len(mask_idc) > min_len:
130
+ # mask_idc = np.random.choice(mask_idc, min_len, replace=False)
131
+ mask[i, mask_idc] = True
132
+
133
+ return torch.tensor(mask)
134
+
135
+
136
+ if __name__ == '__main__':
137
+ mask = compute_mask_indices(
138
+ shape=[4, 500],
139
+ padding_mask=None,
140
+ mask_prob=[0.65, 0.5, 0.65, 0.65],
141
+ mask_length=10,
142
+ mask_type="static",
143
+ mask_other=0.0,
144
+ min_masks=1,
145
+ no_overlap=False,
146
+ min_space=0,
147
+ )
148
+ print(mask)
149
+ print(mask.sum(dim=1))
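Beyond the `__main__` demo above, a sketch of how the boolean span mask might be consumed downstream (the mask-token replacement is an assumption for illustration, not taken from this repository):

```python
import torch
from models.dit.span_mask import compute_mask_indices

B, T, D = 4, 500, 32
feats = torch.randn(B, T, D)
mask = compute_mask_indices((B, T), None, mask_prob=[0.5] * B, mask_length=10)

mask_token = torch.zeros(D)
feats[mask] = mask_token      # replace masked frames with a mask token
print(mask.sum(dim=1))        # number of masked frames per batch item
```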
models/flow_matching.py ADDED
@@ -0,0 +1,1267 @@
1
+ from typing import Any, Optional, Union, List, Sequence
2
+
3
+ import inspect
4
+ import random
5
+
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import copy
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from diffusers.utils.torch_utils import randn_tensor
14
+ from diffusers import FlowMatchEulerDiscreteScheduler
15
+ from diffusers.training_utils import compute_density_for_timestep_sampling
16
+
17
+ from models.autoencoder.autoencoder_base import AutoEncoderBase
18
+ from models.content_encoder.content_encoder import ContentEncoder
19
+ from models.content_adapter import ContentAdapterBase
20
+ from models.common import LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase
21
+ from utils.torch_utilities import (
22
+ create_alignment_path, create_mask_from_length, loss_with_mask,
23
+ trim_or_pad_length
24
+ )
25
+ from safetensors.torch import load_file
26
+
27
+ class FlowMatchingMixin:
28
+ def __init__(
29
+ self,
30
+ cfg_drop_ratio: float = 0.2,
31
+ sample_strategy: str = 'normal',
32
+ num_train_steps: int = 1000
33
+ ) -> None:
34
+ r"""
35
+ Args:
36
+ cfg_drop_ratio (float): Probability of dropping the condition during training for classifier-free guidance.
37
+ sample_strategy (str): Sampling strategy for timesteps during training.
38
+ num_train_steps (int): Number of training steps for the noise scheduler.
39
+ """
40
+ self.sample_strategy = sample_strategy
41
+ self.infer_noise_scheduler = FlowMatchEulerDiscreteScheduler(
42
+ num_train_timesteps=num_train_steps
43
+ )
44
+ self.train_noise_scheduler = copy.deepcopy(self.infer_noise_scheduler)
45
+
46
+ self.classifier_free_guidance = cfg_drop_ratio > 0.0
47
+ self.cfg_drop_ratio = cfg_drop_ratio
48
+
49
+ def get_input_target_and_timesteps(
50
+ self,
51
+ latent: torch.Tensor,
52
+ training: bool = True
53
+ ):
54
+ bsz = latent.shape[0]
55
+ noise = torch.randn_like(latent)
56
+
57
+ if training:
58
+ if self.sample_strategy == 'normal':
59
+ u = compute_density_for_timestep_sampling(
60
+ weighting_scheme="logit_normal",
61
+ batch_size=bsz,
62
+ logit_mean=0,
63
+ logit_std=1,
64
+ mode_scale=None,
65
+ )
66
+ elif self.sample_strategy == 'uniform':
67
+ u = torch.rand(bsz, )  # uniform in [0, 1); randn would yield out-of-range indices
68
+ else:
69
+ raise NotImplementedError(
70
+ f"{self.sample_strategy} samlping for timesteps is not supported now"
71
+ )
72
+ else:
73
+ u = torch.ones(bsz, ) / 2
74
+
75
+ indices = (u * self.train_noise_scheduler.config.num_train_timesteps
76
+ ).long()
77
+
78
+ # train_noise_scheduler.timesteps: a list from 1 ~ num_trainsteps with 1 as interval
79
+ timesteps = self.train_noise_scheduler.timesteps[indices].to(
80
+ device=latent.device
81
+ )
82
+ sigmas = self.get_sigmas(
83
+ timesteps, n_dim=latent.ndim, dtype=latent.dtype
84
+ )
85
+
86
+ noisy_latent = (1.0 - sigmas) * latent + sigmas * noise
87
+
88
+ target = noise - latent
89
+
90
+ return noisy_latent, target, timesteps
91
+
92
+ def get_sigmas(self, timesteps, n_dim=3, dtype=torch.float32):
93
+ device = timesteps.device
94
+
95
+ # a list from 1 declining to 1/num_train_steps
96
+ sigmas = self.train_noise_scheduler.sigmas.to(
97
+ device=device, dtype=dtype
98
+ )
99
+
100
+ schedule_timesteps = self.train_noise_scheduler.timesteps.to(device)
101
+ timesteps = timesteps.to(device)
102
+ step_indices = [(schedule_timesteps == t).nonzero().item()
103
+ for t in timesteps]
104
+
105
+ sigma = sigmas[step_indices].flatten()
106
+ while len(sigma.shape) < n_dim:
107
+ sigma = sigma.unsqueeze(-1)
108
+ return sigma
109
+
110
+ def retrieve_timesteps(
111
+ self,
112
+ num_inference_steps: Optional[int] = None,
113
+ device: Optional[Union[str, torch.device]] = None,
114
+ timesteps: Optional[List[int]] = None,
115
+ sigmas: Optional[List[float]] = None,
116
+ **kwargs,
117
+ ):
118
+ # used in inference, retrieve new timesteps on given inference timesteps
119
+ scheduler = self.infer_noise_scheduler
120
+
121
+ if timesteps is not None and sigmas is not None:
122
+ raise ValueError(
123
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
124
+ )
125
+ if timesteps is not None:
126
+ accepts_timesteps = "timesteps" in set(
127
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
128
+ )
129
+ if not accepts_timesteps:
130
+ raise ValueError(
131
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
132
+ f" timestep schedules. Please check whether you are using the correct scheduler."
133
+ )
134
+ scheduler.set_timesteps(
135
+ timesteps=timesteps, device=device, **kwargs
136
+ )
137
+ timesteps = scheduler.timesteps
138
+ num_inference_steps = len(timesteps)
139
+ elif sigmas is not None:
140
+ accept_sigmas = "sigmas" in set(
141
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
142
+ )
143
+ if not accept_sigmas:
144
+ raise ValueError(
145
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
146
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
147
+ )
148
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
149
+ timesteps = scheduler.timesteps
150
+ num_inference_steps = len(timesteps)
151
+ else:
152
+ scheduler.set_timesteps(
153
+ num_inference_steps, device=device, **kwargs
154
+ )
155
+ timesteps = scheduler.timesteps
156
+ return timesteps, num_inference_steps
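A numerical sketch of the interpolation constructed in `get_input_target_and_timesteps`: with `sigma` in (0, 1], the noisy latent is `(1 - sigma) * x + sigma * noise` and the regression target is `noise - x`, the constant velocity of the straight path from data to noise.

```python
import torch

x = torch.randn(2, 16, 50)            # clean latent
noise = torch.randn_like(x)
sigma = torch.tensor(0.3)

x_t = (1.0 - sigma) * x + sigma * noise
velocity = noise - x                  # what the backbone is trained to predict

# One exact Euler step with the true velocity recovers the clean latent.
x_recovered = x_t - sigma * velocity
assert torch.allclose(x_recovered, x, atol=1e-6)
```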
157
+
158
+
159
+ class ContentEncoderAdapterMixin:
160
+ def __init__(
161
+ self,
162
+ content_encoder: ContentEncoder,
163
+ content_adapter: ContentAdapterBase | None = None
164
+ ):
165
+ self.content_encoder = content_encoder
166
+ self.content_adapter = content_adapter
167
+
168
+ def encode_content(
169
+ self,
170
+ content: list[Any],
171
+ task: list[str],
172
+ device: str | torch.device,
173
+ instruction: torch.Tensor | None = None,
174
+ instruction_lengths: torch.Tensor | None = None
175
+ ):
176
+ content_output: dict[
177
+ str, torch.Tensor] = self.content_encoder.encode_content(
178
+ content, task, device=device
179
+ )
180
+ content, content_mask = content_output["content"], content_output[
181
+ "content_mask"]
182
+
183
+ if instruction is not None:
184
+ instruction_mask = create_mask_from_length(instruction_lengths)
185
+ (
186
+ content,
187
+ content_mask,
188
+ global_duration_pred,
189
+ local_duration_pred,
190
+ ) = self.content_adapter(
191
+ content, content_mask, instruction, instruction_mask
192
+ )
193
+
194
+ return_dict = {
195
+ "content": content,
196
+ "content_mask": content_mask,
197
+ "length_aligned_content": content_output["length_aligned_content"],
198
+ }
199
+ if instruction is not None:
200
+ return_dict["global_duration_pred"] = global_duration_pred
201
+ return_dict["local_duration_pred"] = local_duration_pred
202
+
203
+ return return_dict
204
+
205
+
206
+ class SingleTaskCrossAttentionAudioFlowMatching(
207
+ LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
208
+ FlowMatchingMixin, ContentEncoderAdapterMixin
209
+ ):
210
+ def __init__(
211
+ self,
212
+ autoencoder: nn.Module,
213
+ content_encoder: ContentEncoder,
214
+ backbone: nn.Module,
215
+ cfg_drop_ratio: float = 0.2,
216
+ sample_strategy: str = 'normal',
217
+ num_train_steps: int = 1000,
218
+ pretrained_ckpt: str | None = None,
219
+ ):
220
+ nn.Module.__init__(self)
221
+ FlowMatchingMixin.__init__(
222
+ self, cfg_drop_ratio, sample_strategy, num_train_steps
223
+ )
224
+ ContentEncoderAdapterMixin.__init__(
225
+ self, content_encoder=content_encoder
226
+ )
227
+
228
+ self.autoencoder = autoencoder
229
+ for param in self.autoencoder.parameters():
230
+ param.requires_grad = False
231
+
232
+ if hasattr(self.content_encoder, "audio_encoder"):
233
+ if self.content_encoder.audio_encoder is not None:
234
+ self.content_encoder.audio_encoder.model = self.autoencoder
235
+
236
+ self.backbone = backbone
237
+ self.dummy_param = nn.Parameter(torch.empty(0))
238
+
239
+ if pretrained_ckpt is not None:
240
+ print(f"Load pretrain FlowMatching model from {pretrained_ckpt}")
241
+ pretrained_state_dict = load_file(pretrained_ckpt)
242
+ self.load_pretrained(pretrained_state_dict)
243
+ # missing, unexpected = self.load_state_dict(pretrained_state_dict, strict=False)
244
+ # print("Missing keys:", missing)
245
+ # print("Unexpected keys:", unexpected)
246
+
247
+ # if content_encoder.embed_dim != 1024:
248
+ # self.context_proj = nn.Sequential(
249
+ # nn.Linear(content_encoder.embed_dim, 1024),
250
+ # nn.SiLU(),
251
+ # nn.Linear(1024, 1024),
252
+ # )
253
+ # else:
254
+ # self.context_proj = nn.Identity()
255
+
256
+ def forward(
257
+ self, content: list[Any], condition: list[Any], task: list[str],
258
+ waveform: torch.Tensor, waveform_lengths: torch.Tensor, loss_reduce: bool = True, **kwargs
259
+
260
+ ):
261
+ loss_reduce = self.training or (loss_reduce and not self.training)
262
+ device = self.dummy_param.device
263
+
264
+ self.autoencoder.eval()
265
+ with torch.no_grad():
266
+ latent, latent_mask = self.autoencoder.encode(
267
+ waveform.unsqueeze(1), waveform_lengths
268
+ )
269
+
270
+ content_dict = self.encode_content(content, task, device)
271
+ content, content_mask = content_dict["content"], content_dict[
272
+ "content_mask"]
273
+
274
+ # content = self.context_proj(content)
275
+
276
+ if self.training and self.classifier_free_guidance:
277
+ mask_indices = [
278
+ k for k in range(len(waveform))
279
+ if random.random() < self.cfg_drop_ratio
280
+ ]
281
+ if len(mask_indices) > 0:
282
+ content[mask_indices] = 0
283
+
284
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
285
+ latent,
286
+ training = self.training
287
+ )
288
+
289
+ pred: torch.Tensor = self.backbone(
290
+ x=noisy_latent,
291
+ timesteps=timesteps,
292
+ context=content,
293
+ x_mask=latent_mask,
294
+ context_mask=content_mask
295
+ )
296
+
297
+ diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none")
298
+ diff_loss = loss_with_mask(diff_loss, latent_mask.unsqueeze(1), reduce=loss_reduce)
299
+ #diff_loss = loss_with_mask(diff_loss, latent_mask.unsqueeze(1))
300
+ output = {"diff_loss": diff_loss}
301
+ return output
302
+
303
+ def iterative_denoise(
304
+ self, latent: torch.Tensor, timesteps: list[int], num_steps: int,
305
+ verbose: bool, cfg: bool, cfg_scale: float, backbone_input: dict
306
+ ):
307
+ progress_bar = tqdm(range(num_steps), disable=not verbose)
308
+
309
+ for i, timestep in enumerate(timesteps):
310
+ # expand the latent if we are doing classifier free guidance
311
+ if cfg:
312
+ latent_input = torch.cat([latent, latent])
313
+ else:
314
+ latent_input = latent
315
+
316
+ noise_pred: torch.Tensor = self.backbone(
317
+ x=latent_input, timesteps=timestep, **backbone_input
318
+ )
319
+
320
+ # perform guidance
321
+ if cfg:
322
+ noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
323
+ noise_pred = noise_pred_uncond + cfg_scale * (
324
+ noise_pred_content - noise_pred_uncond
325
+ )
326
+
327
+ latent = self.infer_noise_scheduler.step(
328
+ noise_pred, timestep, latent
329
+ ).prev_sample
330
+
331
+ progress_bar.update(1)
332
+
333
+ progress_bar.close()
334
+
335
+ return latent
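A sketch of the classifier-free guidance blend inside the loop above: the batch is duplicated (unconditional half first, conditional half second) and the two predictions are mixed with the guidance scale.

```python
import torch

cfg_scale = 3.0
noise_pred = torch.randn(4, 16, 50)    # 2 unconditional + 2 conditional samples
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
guided = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
print(guided.shape)                    # torch.Size([2, 16, 50])
```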
336
+
337
+ @torch.no_grad()
338
+ def inference(
339
+ self,
340
+ content: list[Any],
341
+ condition: list[Any],
342
+ task: list[str],
343
+ latent_shape: Sequence[int],
344
+ num_steps: int = 50,
345
+ sway_sampling_coef: float | None = -1.0,
346
+ guidance_scale: float = 3.0,
347
+ num_samples_per_content: int = 1,
348
+ disable_progress: bool = True,
349
+ **kwargs
350
+ ):
351
+ device = self.dummy_param.device
352
+ classifier_free_guidance = guidance_scale > 1.0
353
+ batch_size = len(content) * num_samples_per_content
354
+
355
+ if classifier_free_guidance:
356
+ content, content_mask = self.encode_content_classifier_free(
357
+ content, task, device, num_samples_per_content
358
+ )
359
+ else:
360
+ content_output: dict[
361
+ str, torch.Tensor] = self.content_encoder.encode_content(
362
+ content, task
363
+ )
364
+ content, content_mask = content_output["content"], content_output[
365
+ "content_mask"]
366
+ content = content.repeat_interleave(num_samples_per_content, 0)
367
+ content_mask = content_mask.repeat_interleave(
368
+ num_samples_per_content, 0
369
+ )
370
+
371
+ latent = self.prepare_latent(
372
+ batch_size, latent_shape, content.dtype, device
373
+ )
374
+
375
+ if not sway_sampling_coef:
376
+ sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
377
+ else:
378
+ t = torch.linspace(0, 1, num_steps + 1)
379
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
380
+ sigmas = 1 - t
381
+ timesteps, num_steps = self.retrieve_timesteps(
382
+ num_steps, device, timesteps=None, sigmas=sigmas
383
+ )
384
+
385
+ latent = self.iterative_denoise(
386
+ latent=latent,
387
+ timesteps=timesteps,
388
+ num_steps=num_steps,
389
+ verbose=not disable_progress,
390
+ cfg=classifier_free_guidance,
391
+ cfg_scale=guidance_scale,
392
+ backbone_input={
393
+ "context": content,
394
+ "context_mask": content_mask,
395
+ },
396
+ )
397
+
398
+ waveform = self.autoencoder.decode(latent)
399
+
400
+ return waveform
401
+
402
+ def prepare_latent(
403
+ self, batch_size: int, latent_shape: Sequence[int], dtype: torch.dtype,
404
+ device: str
405
+ ):
406
+ shape = (batch_size, *latent_shape)
407
+ latent = randn_tensor(
408
+ shape, generator=None, device=device, dtype=dtype
409
+ )
410
+ return latent
411
+
412
+ def encode_content_classifier_free(
413
+ self,
414
+ content: list[Any],
415
+ task: list[str],
416
+ device,
417
+ num_samples_per_content: int = 1
418
+ ):
419
+ content_dict = self.content_encoder.encode_content(
420
+ content, task, device
421
+ )
422
+ content, content_mask = content_dict["content"], content_dict["content_mask"]
423
+ # content, content_mask = self.content_encoder.encode_content(
424
+ # content, task, device=device
425
+ # )
426
+
427
+ content = content.repeat_interleave(num_samples_per_content, 0)
428
+ content_mask = content_mask.repeat_interleave(
429
+ num_samples_per_content, 0
430
+ )
431
+
432
+ # get unconditional embeddings for classifier free guidance
433
+ uncond_content = torch.zeros_like(content)
434
+ uncond_content_mask = content_mask.detach().clone()
435
+
436
+ uncond_content = uncond_content.repeat_interleave(
437
+ num_samples_per_content, 0
438
+ )
439
+ uncond_content_mask = uncond_content_mask.repeat_interleave(
440
+ num_samples_per_content, 0
441
+ )
442
+
443
+ # For classifier free guidance, we need to do two forward passes.
444
+ # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
445
+ content = torch.cat([uncond_content, content])
446
+ content_mask = torch.cat([uncond_content_mask, content_mask])
447
+
448
+ return content, content_mask
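The `inference` method above uses a sway-sampling schedule for the sigmas; a small sketch with the default coefficient of -1.0 shows that the warped grid spends more steps near `sigma = 1`, i.e. early in denoising.

```python
import torch

num_steps = 10
coef = -1.0
t = torch.linspace(0, 1, num_steps + 1)
t = t + coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
sigmas = 1 - t
print(sigmas)   # monotonically decreasing from 1.0 to 0.0, denser near 1.0
```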
449
+
450
+ class MultiContentAudioFlowMatching(SingleTaskCrossAttentionAudioFlowMatching):
451
+ def __init__(
452
+ self,
453
+ autoencoder: AutoEncoderBase,
454
+ content_encoder: ContentEncoder,
455
+ backbone: nn.Module,
456
+ cfg_drop_ratio: float = 0.2,
457
+ sample_strategy: str = 'normal',
458
+ num_train_steps: int = 1000,
459
+ pretrained_ckpt: str | None = None,
460
+ embed_dim: int = 1024,
461
+ ):
462
+ super().__init__(
463
+ autoencoder=autoencoder,
464
+ content_encoder=content_encoder,
465
+ backbone=backbone,
466
+ cfg_drop_ratio=cfg_drop_ratio,
467
+ sample_strategy=sample_strategy,
468
+ num_train_steps=num_train_steps,
469
+ pretrained_ckpt=pretrained_ckpt,
470
+ )
471
+
472
+ def forward(
473
+ self,
474
+ content: list[Any],
475
+ duration: Sequence[float],
476
+ task: list[str],
477
+ waveform: torch.Tensor,
478
+ waveform_lengths: torch.Tensor,
479
+ loss_reduce: bool = True,
480
+ **kwargs
481
+ ):
482
+ device = self.dummy_param.device
483
+ loss_reduce = self.training or (loss_reduce and not self.training)
484
+
485
+ self.autoencoder.eval()
486
+
487
+ with torch.no_grad():
488
+ latent, latent_mask = self.autoencoder.encode(
489
+ waveform.unsqueeze(1), waveform_lengths
490
+ ) # latent [B, 128, 500/T=10s], latent_mask [B, 500/T=10s]
491
+
492
+ content_dict = self.encode_content(content, task, device)
493
+ context, context_mask, length_aligned_content = content_dict["content"], content_dict[
494
+ "content_mask"], content_dict["length_aligned_content"]
495
+
496
+ # --------------------------------------------------------------------
497
+ # prepare latent and noise
498
+ # --------------------------------------------------------------------
499
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
500
+ latent,
501
+ training = self.training
502
+ )
503
+
504
+ # --------------------------------------------------------------------
505
+ # prepare input to the backbone
506
+ # --------------------------------------------------------------------
507
+ # TODO compatibility with 2D spectrogram VAE
508
+
509
+ latent_length = noisy_latent.size(self.autoencoder.time_dim)
510
+ time_aligned_content = trim_or_pad_length(
511
+ length_aligned_content, latent_length, 1
512
+ )
513
+
514
+ # --------------------------------------------------------------------
515
+ # classifier free guidance
516
+ # --------------------------------------------------------------------
517
+ if self.training and self.classifier_free_guidance:
518
+ mask_indices = [
519
+ k for k in range(len(waveform))
520
+ if random.random() < self.cfg_drop_ratio
521
+ ]
522
+ if len(mask_indices) > 0:
523
+ context[mask_indices] = 0
524
+ time_aligned_content[mask_indices] = 0
525
+
526
+ pred: torch.Tensor = self.backbone(
527
+ x=noisy_latent,
528
+ x_mask=latent_mask,
529
+ timesteps=timesteps,
530
+ context=context,
531
+ context_mask=context_mask,
532
+ time_aligned_context=time_aligned_content,
533
+ )
534
+
535
+ pred = pred.transpose(1, self.autoencoder.time_dim)
536
+ target = target.transpose(1, self.autoencoder.time_dim)
537
+ diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none")
538
+ diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)
539
+
540
+ return {
541
+ "diff_loss": diff_loss,
542
+ }
543
+
544
+ def inference(
545
+ self,
546
+ content: list[Any],
547
+ task: list[str],
548
+ latent_shape: Sequence[int],
549
+ num_steps: int = 50,
550
+ sway_sampling_coef: float | None = -1.0,
551
+ guidance_scale: float = 3.0,
552
+ disable_progress: bool = True,
553
+ **kwargs
554
+ ):
555
+ device = self.dummy_param.device
556
+ classifier_free_guidance = guidance_scale > 1.0
557
+ batch_size = len(content)
558
+
559
+
560
+ content_dict: dict[
561
+ str, torch.Tensor] = self.encode_content(
562
+ content, task, device
563
+ )
564
+ context, context_mask, length_aligned_content = \
565
+ content_dict["content"], content_dict[
566
+ "content_mask"], content_dict["length_aligned_content"]
567
+
568
+ shape = (batch_size, *latent_shape)
569
+ latent_length = shape[self.autoencoder.time_dim]
570
+ time_aligned_content = trim_or_pad_length(
571
+ length_aligned_content, latent_length, 1
572
+ )
573
+
574
+ # --------------------------------------------------------------------
575
+ # prepare unconditional input
576
+ # --------------------------------------------------------------------
577
+ if classifier_free_guidance:
578
+ uncond_time_aligned_content = torch.zeros_like(
579
+ time_aligned_content
580
+ )
581
+ uncond_context = torch.zeros_like(context)
582
+ uncond_context_mask = context_mask.detach().clone()
583
+ time_aligned_content = torch.cat([
584
+ uncond_time_aligned_content, time_aligned_content
585
+ ])
586
+ context = torch.cat([uncond_context, context])
587
+ context_mask = torch.cat([uncond_context_mask, context_mask])
588
+
589
+
590
+ latent = randn_tensor(
591
+ shape, generator=None, device=device, dtype=context.dtype
592
+ )
593
+
594
+ if not sway_sampling_coef:
595
+ sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
596
+ else:
597
+ t = torch.linspace(0, 1, num_steps + 1)
598
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
599
+ sigmas = 1 - t
600
+ timesteps, num_steps = self.retrieve_timesteps(
601
+ num_steps, device, timesteps=None, sigmas=sigmas
602
+ )
603
+ latent = self.iterative_denoise(
604
+ latent=latent,
605
+ timesteps=timesteps,
606
+ num_steps=num_steps,
607
+ verbose=not disable_progress,
608
+ cfg=classifier_free_guidance,
609
+ cfg_scale=guidance_scale,
610
+ backbone_input={
611
+ "context": context,
612
+ "context_mask": context_mask,
613
+ "time_aligned_context": time_aligned_content,
614
+ }
615
+ )
616
+
617
+ waveform = self.autoencoder.decode(latent)
618
+ return waveform
619
+
620
+ class DurationAdapterMixin:
621
+ def __init__(
622
+ self,
623
+ latent_token_rate: int,
624
+ offset: float = 1.0,
625
+ frame_resolution: float | None = None
626
+ ):
627
+ self.latent_token_rate = latent_token_rate
628
+ self.offset = offset
629
+ self.frame_resolution = frame_resolution
630
+
631
+ def get_global_duration_loss(
632
+ self,
633
+ pred: torch.Tensor,
634
+ latent_mask: torch.Tensor,
635
+ reduce: bool = True,
636
+ ):
637
+ target = torch.log(
638
+ latent_mask.sum(1) / self.latent_token_rate + self.offset
639
+ )
640
+ loss = F.mse_loss(target, pred, reduction="mean" if reduce else "none")
641
+ return loss
642
+
643
+ def get_local_duration_loss(
644
+ self, ground_truth: torch.Tensor, pred: torch.Tensor,
645
+ mask: torch.Tensor, is_time_aligned: Sequence[bool], reduce: bool
646
+ ):
647
+ n_frames = torch.round(ground_truth / self.frame_resolution)
648
+ target = torch.log(n_frames + self.offset)
649
+ loss = loss_with_mask(
650
+ (target - pred)**2,
651
+ mask,
652
+ reduce=False,
653
+ )
654
+ loss *= is_time_aligned
655
+ if reduce:
656
+ if is_time_aligned.sum().item() == 0:
657
+ loss *= 0.0
658
+ loss = loss.mean()
659
+ else:
660
+ loss = loss.sum() / is_time_aligned.sum()
661
+
662
+ return loss
663
+
664
+ def prepare_local_duration(self, pred: torch.Tensor, mask: torch.Tensor):
665
+ pred = torch.exp(pred) * mask
666
+ pred = torch.ceil(pred) - self.offset
667
+ pred *= self.frame_resolution
668
+ return pred
669
+
670
+ def prepare_global_duration(
671
+ self,
672
+ global_pred: torch.Tensor,
673
+ local_pred: torch.Tensor,
674
+ is_time_aligned: Sequence[bool],
675
+ use_local: bool = True,
676
+ ):
677
+ """
678
+ global_pred: predicted duration value, processed by logarithmic and offset
679
+ local_pred: predicted latent length
680
+ """
681
+ global_pred = torch.exp(global_pred) - self.offset
682
+ result = global_pred
683
+ # avoid error accumulation for each frame
684
+ if use_local:
685
+ pred_from_local = torch.round(local_pred * self.latent_token_rate)
686
+ pred_from_local = pred_from_local.sum(1) / self.latent_token_rate
687
+ result[is_time_aligned] = pred_from_local[is_time_aligned]
688
+
689
+ return result
690
+
691
+ def expand_by_duration(
692
+ self,
693
+ x: torch.Tensor,
694
+ content_mask: torch.Tensor,
695
+ local_duration: torch.Tensor,
696
+ global_duration: torch.Tensor | None = None,
697
+ ):
698
+ n_latents = torch.round(local_duration * self.latent_token_rate)
699
+ if global_duration is not None:
700
+ latent_length = torch.round(
701
+ global_duration * self.latent_token_rate
702
+ )
703
+ else:
704
+ latent_length = n_latents.sum(1)
705
+ latent_mask = create_mask_from_length(latent_length).to(
706
+ content_mask.device
707
+ )
708
+ attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
709
+ align_path = create_alignment_path(n_latents, attn_mask)
710
+ expanded_x = torch.matmul(align_path.transpose(1, 2).to(x.dtype), x)
711
+ return expanded_x, latent_mask
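A sketch of the duration parameterisation used by this mixin (the numbers are illustrative): local targets are `log(n_frames + offset)`, and `prepare_local_duration` inverts them with `ceil(exp(.)) - offset` scaled by the frame resolution.

```python
import torch

offset, frame_resolution = 1.0, 0.02
duration_sec = torch.tensor([[0.10, 0.24]])               # per-token durations
n_frames = torch.round(duration_sec / frame_resolution)   # tensor([[ 5., 12.]])
target = torch.log(n_frames + offset)                     # regression target

# A perfect prediction maps back onto the frame grid (up to rounding).
pred_frames = torch.ceil(torch.exp(target)) - offset      # ~[[5., 12.]]
pred_seconds = pred_frames * frame_resolution             # ~[[0.10, 0.24]]
```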
712
+
713
+
714
+ class CrossAttentionAudioFlowMatching(
715
+ SingleTaskCrossAttentionAudioFlowMatching, DurationAdapterMixin
716
+ ):
717
+ def __init__(
718
+ self,
719
+ autoencoder: AutoEncoderBase,
720
+ content_encoder: ContentEncoder,
721
+ content_adapter: ContentAdapterBase,
722
+ backbone: nn.Module,
723
+ content_dim: int,
724
+ frame_resolution: float,
725
+ duration_offset: float = 1.0,
726
+ cfg_drop_ratio: float = 0.2,
727
+ sample_strategy: str = 'normal',
728
+ num_train_steps: int = 1000
729
+ ):
730
+ super().__init__(
731
+ autoencoder=autoencoder,
732
+ content_encoder=content_encoder,
733
+ backbone=backbone,
734
+ cfg_drop_ratio=cfg_drop_ratio,
735
+ sample_strategy=sample_strategy,
736
+ num_train_steps=num_train_steps,
737
+ )
738
+ ContentEncoderAdapterMixin.__init__(
739
+ self,
740
+ content_encoder=content_encoder,
741
+ content_adapter=content_adapter
742
+ )
743
+ DurationAdapterMixin.__init__(
744
+ self,
745
+ latent_token_rate=autoencoder.latent_token_rate,
746
+ offset=duration_offset
747
+ )
748
+
749
+ def encode_content_with_instruction(
750
+ self, content: list[Any], task: list[str], device,
751
+ instruction: torch.Tensor, instruction_lengths: torch.Tensor
752
+ ):
753
+ content_dict = self.encode_content(
754
+ content, task, device, instruction, instruction_lengths
755
+ )
756
+ return (
757
+ content_dict["content"], content_dict["content_mask"],
758
+ content_dict["global_duration_pred"],
759
+ content_dict["local_duration_pred"],
760
+ content_dict["length_aligned_content"]
761
+ )
762
+
763
+ def forward(
764
+ self,
765
+ content: list[Any],
766
+ task: list[str],
767
+ waveform: torch.Tensor,
768
+ waveform_lengths: torch.Tensor,
769
+ instruction: torch.Tensor,
770
+ instruction_lengths: torch.Tensor,
771
+ loss_reduce: bool = True,
772
+ **kwargs
773
+ ):
774
+ device = self.dummy_param.device
775
+ loss_reduce = self.training or (loss_reduce and not self.training)
776
+
777
+ self.autoencoder.eval()
778
+ with torch.no_grad():
779
+ latent, latent_mask = self.autoencoder.encode(
780
+ waveform.unsqueeze(1), waveform_lengths
781
+ )
782
+
783
+ content, content_mask, global_duration_pred, _, _ = \
784
+ self.encode_content_with_instruction(
785
+ content, task, device, instruction, instruction_lengths
786
+ )
787
+
788
+ global_duration_loss = self.get_global_duration_loss(
789
+ global_duration_pred, latent_mask, reduce=loss_reduce
790
+ )
791
+
792
+ if self.training and self.classifier_free_guidance:
793
+ mask_indices = [
794
+ k for k in range(len(waveform))
795
+ if random.random() < self.cfg_drop_ratio
796
+ ]
797
+ if len(mask_indices) > 0:
798
+ content[mask_indices] = 0
799
+
800
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
801
+ latent,
802
+ training = self.training
803
+ )
804
+
805
+ pred: torch.Tensor = self.backbone(
806
+ x=noisy_latent,
807
+ timesteps=timesteps,
808
+ context=content,
809
+ x_mask=latent_mask,
810
+ context_mask=content_mask,
811
+ )
812
+ pred = pred.transpose(1, self.autoencoder.time_dim)
813
+ target = target.transpose(1, self.autoencoder.time_dim)
814
+ diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none")
815
+ diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)
816
+
817
+ return {
818
+ "diff_loss": diff_loss,
819
+ "global_duration_loss": global_duration_loss,
820
+ }
821
+
822
+ @torch.no_grad()
823
+ def inference(
824
+ self,
825
+ content: list[Any],
826
+ condition: list[Any],
827
+ task: list[str],
828
+ is_time_aligned: Sequence[bool],
829
+ instruction: torch.Tensor,
830
+ instruction_lengths: torch.Tensor,
831
+ num_steps: int = 20,
832
+ sway_sampling_coef: float | None = -1.0,
833
+ guidance_scale: float = 3.0,
834
+ disable_progress=True,
835
+ use_gt_duration: bool = False,
836
+ **kwargs
837
+ ):
838
+ device = self.dummy_param.device
839
+ classifier_free_guidance = guidance_scale > 1.0
840
+
841
+ (
842
+ content,
843
+ content_mask,
844
+ global_duration_pred,
845
+ local_duration_pred,
846
+ _,
847
+ ) = self.encode_content_with_instruction(
848
+ content, task, device, instruction, instruction_lengths
849
+ )
850
+ batch_size = content.size(0)
851
+
852
+ if use_gt_duration:
853
+ raise NotImplementedError(
854
+ "Using ground truth global duration only is not implemented yet"
855
+ )
856
+
857
+ # prepare global duration
858
+ global_duration = self.prepare_global_duration(
859
+ global_duration_pred,
860
+ local_duration_pred,
861
+ is_time_aligned,
862
+ use_local=False
863
+ )
864
+ latent_length = torch.round(global_duration * self.latent_token_rate)
865
+ latent_mask = create_mask_from_length(latent_length).to(device)
866
+ max_latent_length = latent_mask.sum(1).max().item()
867
+
868
+ # prepare latent and noise
869
+ if classifier_free_guidance:
870
+ uncond_context = torch.zeros_like(content)
871
+ uncond_content_mask = content_mask.detach().clone()
872
+ context = torch.cat([uncond_context, content])
873
+ context_mask = torch.cat([uncond_content_mask, content_mask])
874
+ else:
875
+ context = content
876
+ context_mask = content_mask
877
+
878
+ latent_shape = tuple(
879
+ max_latent_length if dim is None else dim
880
+ for dim in self.autoencoder.latent_shape
881
+ )
882
+ shape = (batch_size, *latent_shape)
883
+ latent = randn_tensor(
884
+ shape, generator=None, device=device, dtype=content.dtype
885
+ )
886
+ if not sway_sampling_coef:
887
+ sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
888
+ else:
889
+ t = torch.linspace(0, 1, num_steps + 1)
890
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
891
+ sigmas = 1 - t
892
+ timesteps, num_steps = self.retrieve_timesteps(
893
+ num_steps, device, timesteps=None, sigmas=sigmas
894
+ )
895
+ latent = self.iterative_denoise(
896
+ latent=latent,
897
+ timesteps=timesteps,
898
+ num_steps=num_steps,
899
+ verbose=not disable_progress,
900
+ cfg=classifier_free_guidance,
901
+ cfg_scale=guidance_scale,
902
+ backbone_input={
903
+ "x_mask": latent_mask,
904
+ "context": context,
905
+ "context_mask": context_mask,
906
+ }
907
+ )
908
+
909
+ waveform = self.autoencoder.decode(latent)
910
+ return waveform
911
+
912
+
913
+ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
914
+ def __init__(
915
+ self,
916
+ autoencoder: AutoEncoderBase,
917
+ content_encoder: ContentEncoder,
918
+ content_adapter: ContentAdapterBase,
919
+ backbone: nn.Module,
920
+ content_dim: int,
921
+ frame_resolution: float,
922
+ duration_offset: float = 1.0,
923
+ cfg_drop_ratio: float = 0.2,
924
+ sample_strategy: str = 'normal',
925
+ num_train_steps: int = 1000
926
+ ):
927
+
928
+ super().__init__(
929
+ autoencoder=autoencoder,
930
+ content_encoder=content_encoder,
931
+ content_adapter=content_adapter,
932
+ backbone=backbone,
933
+ content_dim=content_dim,
934
+ frame_resolution=frame_resolution,
935
+ duration_offset=duration_offset,
936
+ cfg_drop_ratio=cfg_drop_ratio,
937
+ sample_strategy=sample_strategy,
938
+ num_train_steps=num_train_steps
939
+ )
940
+ DurationAdapterMixin.__init__(
941
+ self,
942
+ latent_token_rate=autoencoder.latent_token_rate,
943
+ offset=duration_offset,
944
+ frame_resolution=frame_resolution
945
+ )
946
+ self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim))
947
+ self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim))
948
+
949
+ def get_backbone_input(
950
+ self, target_length: int, content: torch.Tensor,
951
+ content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
952
+ length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
953
+ ):
954
+ # TODO compatility for 2D spectrogram VAE
955
+ time_aligned_content = trim_or_pad_length(
956
+ time_aligned_content, target_length, 1
957
+ )
958
+ length_aligned_content = trim_or_pad_length(
959
+ length_aligned_content, target_length, 1
960
+ )
961
+ # time_aligned_content: from monotonic aligned input, without frame expansion (phoneme)
962
+ # length_aligned_content: from aligned input (f0/energy)
963
+ time_aligned_content = time_aligned_content + length_aligned_content
964
+ time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
965
+ time_aligned_content.dtype
966
+ )
967
+
968
+ context = content
969
+ context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
970
+ # only use the first dummy non time aligned embedding
971
+ context_mask = content_mask.detach().clone()
972
+ context_mask[is_time_aligned, 1:] = False
973
+
974
+ # truncate dummy non time aligned context
975
+ if is_time_aligned.sum().item() < content.size(0):
976
+ trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
977
+ else:
978
+ trunc_nta_length = content.size(1)
979
+ context = context[:, :trunc_nta_length]
980
+ context_mask = context_mask[:, :trunc_nta_length]
981
+
982
+ return context, context_mask, time_aligned_content
983
+
984
+ def forward(
985
+ self,
986
+ content: list[Any],
987
+ duration: Sequence[float],
988
+ task: list[str],
989
+ is_time_aligned: Sequence[bool],
990
+ waveform: torch.Tensor,
991
+ waveform_lengths: torch.Tensor,
992
+ instruction: torch.Tensor,
993
+ instruction_lengths: torch.Tensor,
994
+ loss_reduce: bool = True,
995
+ **kwargs
996
+ ):
997
+ device = self.dummy_param.device
998
+ loss_reduce = self.training or (loss_reduce and not self.training)
999
+
1000
+ self.autoencoder.eval()
1001
+ with torch.no_grad():
1002
+ latent, latent_mask = self.autoencoder.encode(
1003
+ waveform.unsqueeze(1), waveform_lengths
1004
+ )
1005
+
1006
+ (
1007
+ content, content_mask, global_duration_pred, local_duration_pred,
1008
+ length_aligned_content
1009
+ ) = self.encode_content_with_instruction(
1010
+ content, task, device, instruction, instruction_lengths
1011
+ )
1012
+
1013
+ # truncate unused non time aligned duration prediction
1014
+ if is_time_aligned.sum() > 0:
1015
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
1016
+ else:
1017
+ trunc_ta_length = content.size(1)
1018
+
1019
+ # duration loss
1020
+ local_duration_pred = local_duration_pred[:, :trunc_ta_length]
1021
+ ta_content_mask = content_mask[:, :trunc_ta_length]
1022
+ local_duration_loss = self.get_local_duration_loss(
1023
+ duration,
1024
+ local_duration_pred,
1025
+ ta_content_mask,
1026
+ is_time_aligned,
1027
+ reduce=loss_reduce
1028
+ )
1029
+
1030
+ global_duration_loss = self.get_global_duration_loss(
1031
+ global_duration_pred, latent_mask, reduce=loss_reduce
1032
+ )
1033
+
1034
+ # --------------------------------------------------------------------
1035
+ # prepare latent and noise
1036
+ # --------------------------------------------------------------------
1037
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
1038
+ latent,
1039
+ training=self.training
1040
+ )
1041
+
1042
+ # --------------------------------------------------------------------
1043
+ # duration adapter
1044
+ # --------------------------------------------------------------------
1045
+ if is_time_aligned.sum() == 0 and \
1046
+ duration.size(1) < content_mask.size(1):
1047
+ duration = F.pad(
1048
+ duration, (0, content_mask.size(1) - duration.size(1))
1049
+ )
1050
+ time_aligned_content, _ = self.expand_by_duration(
1051
+ x=content[:, :trunc_ta_length],
1052
+ content_mask=ta_content_mask,
1053
+ local_duration=duration,
1054
+ )
1055
+
1056
+ # --------------------------------------------------------------------
1057
+ # prepare input to the backbone
1058
+ # --------------------------------------------------------------------
1059
+ # TODO: compatibility with 2D spectrogram VAE
1060
+ latent_length = noisy_latent.size(self.autoencoder.time_dim)
1061
+ context, context_mask, time_aligned_content = self.get_backbone_input(
1062
+ latent_length, content, content_mask, time_aligned_content,
1063
+ length_aligned_content, is_time_aligned
1064
+ )
1065
+
1066
+ # --------------------------------------------------------------------
1067
+ # classifier free guidance
1068
+ # --------------------------------------------------------------------
1069
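+ # randomly drop all conditioning (context and time-aligned content) for a
+ # cfg_drop_ratio fraction of the batch so the model also learns the
+ # unconditional distribution used by classifier-free guidance at inference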
+ if self.training and self.classifier_free_guidance:
1070
+ mask_indices = [
1071
+ k for k in range(len(waveform))
1072
+ if random.random() < self.cfg_drop_ratio
1073
+ ]
1074
+ if len(mask_indices) > 0:
1075
+ context[mask_indices] = 0
1076
+ time_aligned_content[mask_indices] = 0
1077
+
1078
+ pred: torch.Tensor = self.backbone(
1079
+ x=noisy_latent,
1080
+ x_mask=latent_mask,
1081
+ timesteps=timesteps,
1082
+ context=context,
1083
+ context_mask=context_mask,
1084
+ time_aligned_context=time_aligned_content,
1085
+ )
1086
+ pred = pred.transpose(1, self.autoencoder.time_dim)
1087
+ target = target.transpose(1, self.autoencoder.time_dim)
1088
+ diff_loss = F.mse_loss(pred, target, reduction="none")
1089
+ diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)
1090
+ return {
1091
+ "diff_loss": diff_loss,
1092
+ "local_duration_loss": local_duration_loss,
1093
+ "global_duration_loss": global_duration_loss,
1094
+ }
1095
+
1096
+ def inference(
1097
+ self,
1098
+ content: list[Any],
1099
+ task: list[str],
1100
+ is_time_aligned: Sequence[bool],
1101
+ instruction: torch.Tensor,
1102
+ instruction_lengths: Sequence[int],
1103
+ num_steps: int = 20,
1104
+ sway_sampling_coef: float | None = -1.0,
1105
+ guidance_scale: float = 3.0,
1106
+ disable_progress: bool = True,
1107
+ use_gt_duration: bool = False,
1108
+ **kwargs
1109
+ ):
1110
+ device = self.dummy_param.device
1111
+ classifier_free_guidance = guidance_scale > 1.0
1112
+
1113
+ (
1114
+ content, content_mask, global_duration_pred, local_duration_pred,
1115
+ length_aligned_content
1116
+ ) = self.encode_content_with_instruction(
1117
+ content, task, device, instruction, instruction_lengths
1118
+ )
1119
+ # print("content std: ", content.std())
1120
+ batch_size = content.size(0)
1121
+
1122
+ # truncate dummy time-aligned duration prediction
1123
+ is_time_aligned = torch.as_tensor(is_time_aligned)
1124
+ if is_time_aligned.sum() > 0:
1125
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
1126
+ else:
1127
+ trunc_ta_length = content.size(1)
1128
+
1129
+ # prepare local duration
1130
+ local_duration = self.prepare_local_duration(
1131
+ local_duration_pred, content_mask
1132
+ )
1133
+ local_duration = local_duration[:, :trunc_ta_length]
1134
+ # use ground truth duration
1135
+ if use_gt_duration and "duration" in kwargs:
1136
+ local_duration = torch.as_tensor(kwargs["duration"]).to(device)
1137
+
1138
+ # prepare global duration
1139
+ global_duration = self.prepare_global_duration(
1140
+ global_duration_pred, local_duration, is_time_aligned
1141
+ )
1142
+
1143
+ # --------------------------------------------------------------------
1144
+ # duration adapter
1145
+ # --------------------------------------------------------------------
1146
+ time_aligned_content, latent_mask = self.expand_by_duration(
1147
+ x=content[:, :trunc_ta_length],
1148
+ content_mask=content_mask[:, :trunc_ta_length],
1149
+ local_duration=local_duration,
1150
+ global_duration=global_duration,
1151
+ )
1152
+
1153
+ context, context_mask, time_aligned_content = self.get_backbone_input(
1154
+ target_length=time_aligned_content.size(1),
1155
+ content=content,
1156
+ content_mask=content_mask,
1157
+ time_aligned_content=time_aligned_content,
1158
+ length_aligned_content=length_aligned_content,
1159
+ is_time_aligned=is_time_aligned
1160
+ )
1161
+
1162
+ # --------------------------------------------------------------------
1163
+ # prepare unconditional input
1164
+ # --------------------------------------------------------------------
1165
+ if classifier_free_guidance:
1166
+ uncond_time_aligned_content = torch.zeros_like(
1167
+ time_aligned_content
1168
+ )
1169
+ uncond_context = torch.zeros_like(context)
1170
+ uncond_context_mask = context_mask.detach().clone()
1171
+ time_aligned_content = torch.cat([
1172
+ uncond_time_aligned_content, time_aligned_content
1173
+ ])
1174
+ context = torch.cat([uncond_context, context])
1175
+ context_mask = torch.cat([uncond_context_mask, context_mask])
1176
+ latent_mask = torch.cat([
1177
+ latent_mask, latent_mask.detach().clone()
1178
+ ])
1179
+
1180
+ # --------------------------------------------------------------------
1181
+ # prepare input to the backbone
1182
+ # --------------------------------------------------------------------
1183
+ latent_length = latent_mask.sum(1).max().item()
1184
+ latent_shape = tuple(
1185
+ latent_length if dim is None else dim
1186
+ for dim in self.autoencoder.latent_shape
1187
+ )
1188
+ shape = (batch_size, *latent_shape)
1189
+ latent = randn_tensor(
1190
+ shape, generator=None, device=device, dtype=content.dtype
1191
+ )
1192
+
1193
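+ # build the sigma schedule: uniform when sway_sampling_coef is None/0,
+ # otherwise warp the uniform grid with a cosine "sway" term and set sigmas = 1 - t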
+ if not sway_sampling_coef:
1194
+ sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
1195
+ else:
1196
+ t = torch.linspace(0, 1, num_steps + 1)
1197
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
1198
+ sigmas = 1 - t
1199
+ timesteps, num_steps = self.retrieve_timesteps(
1200
+ num_steps, device, timesteps=None, sigmas=sigmas
1201
+ )
1202
+ latent = self.iterative_denoise(
1203
+ latent=latent,
1204
+ timesteps=timesteps,
1205
+ num_steps=num_steps,
1206
+ verbose=not disable_progress,
1207
+ cfg=classifier_free_guidance,
1208
+ cfg_scale=guidance_scale,
1209
+ backbone_input={
1210
+ "x_mask": latent_mask,
1211
+ "context": context,
1212
+ "context_mask": context_mask,
1213
+ "time_aligned_context": time_aligned_content,
1214
+ }
1215
+ )
1216
+
1217
+ waveform = self.autoencoder.decode(latent)
1218
+ return waveform
1219
+
1220
+
1221
+ class DoubleContentAudioFlowMatching(DummyContentAudioFlowMatching):
1222
+ def get_backbone_input(
1223
+ self, target_length: int, content: torch.Tensor,
1224
+ content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
1225
+ length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
1226
+ ):
1227
+ # TODO: compatibility with 2D spectrogram VAE
1228
+ time_aligned_content = trim_or_pad_length(
1229
+ time_aligned_content, target_length, 1
1230
+ )
1231
+ length_aligned_content = trim_or_pad_length(
1232
+ length_aligned_content, target_length, 1
1233
+ )
1234
+ # time_aligned_content: from monotonically aligned input, without frame expansion (phoneme)
1235
+ # length_aligned_content: from length-aligned input (f0/energy)
1236
+ time_aligned_content = time_aligned_content + length_aligned_content
1237
+
1238
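+ # unlike the dummy variant, no placeholder substitution: every sample keeps
+ # both its full cross-attention context and its full time-aligned content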
+ context = content
1239
+ context_mask = content_mask.detach().clone()
1240
+
1241
+ return context, context_mask, time_aligned_content
1242
+
1243
+
1244
+ class HybridContentAudioFlowMatching(DummyContentAudioFlowMatching):
1245
+ def get_backbone_input(
1246
+ self, target_length: int, content: torch.Tensor,
1247
+ content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
1248
+ length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
1249
+ ):
1250
+ # TODO: compatibility with 2D spectrogram VAE
1251
+ time_aligned_content = trim_or_pad_length(
1252
+ time_aligned_content, target_length, 1
1253
+ )
1254
+ length_aligned_content = trim_or_pad_length(
1255
+ length_aligned_content, target_length, 1
1256
+ )
1257
+ # time_aligned_content: from monotonically aligned input, without frame expansion (phoneme)
1258
+ # length_aligned_content: from length-aligned input (f0/energy)
1259
+ time_aligned_content = time_aligned_content + length_aligned_content
1260
+ time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
1261
+ time_aligned_content.dtype
1262
+ )
1263
+
1264
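+ # hybrid variant: non-time-aligned samples fall back to the dummy time-aligned
+ # embedding, but the full cross-attention context is kept for every sample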
+ context = content
1265
+ context_mask = content_mask.detach().clone()
1266
+
1267
+ return context, context_mask, time_aligned_content
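
For reference, the three `*ContentAudioFlowMatching` variants above only differ in `get_backbone_input`, so they share the same `inference` entry point. Below is a minimal sketch of how it might be driven, assuming a fully constructed model and a batch prepared by the project's own data pipeline; the factory and dataloader names are placeholders, not part of this repo's API.

```python
import torch

# placeholders: the real objects come from configs/infer.yaml and the repo's builders
model = build_flow_matching_model()   # hypothetical factory
batch = next(iter(dataloader))        # hypothetical prepared batch

model.eval()
with torch.no_grad():
    waveform = model.inference(
        content=batch["content"],                    # task-dependent content (e.g. text / phonemes)
        task=batch["task"],                          # one task string per sample
        is_time_aligned=batch["is_time_aligned"],    # per-sample flag
        instruction=batch["instruction"],
        instruction_lengths=batch["instruction_lengths"],
        num_steps=20,                                # flow-matching denoising steps
        sway_sampling_coef=-1.0,                     # None/0 gives a uniform sigma schedule
        guidance_scale=3.0,                          # >1.0 enables classifier-free guidance
        disable_progress=False,
    )
# `waveform` is the audio decoded by the autoencoder, one clip per batch item.
```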
requirements.txt CHANGED
@@ -1,3 +1,149 @@
1
- gradio==4.4.1
2
- transformers==4.31.0
3
- huggingface-hub==0.16.4
1
+ absl-py==2.3.0
2
+ accelerate==1.2.1
3
+ alias-free-torch==0.0.6
4
+ annotated-types==0.7.0
5
+ antlr4-python3-runtime==4.9.3
6
+ astunparse==1.6.3
7
+ attrs==22.2.0
8
+ audioread==3.0.1
9
+ av==11.0.0
10
+ bitarray==3.7.1
11
+ boto3==1.38.36
12
+ botocore==1.38.36
13
+ braceexpand==0.1.7
14
+ brotlipy==0.7.0
15
+ click==8.1.8
16
+ colorama==0.4.6
17
+ conda==23.1.0
18
+ conda-build==3.23.3
19
+ contourpy==1.2.0
20
+ cycler==0.12.1
21
+ dcase-util==0.2.20
22
+ diffusers==0.33.1
23
+ dnspython==2.3.0
24
+ docker-pycreds==0.4.0
25
+ einops==0.7.0
26
+ exceptiongroup==1.1.1
27
+ expecttest==0.1.4
28
+ fire==0.7.0
29
+ fonttools==4.47.2
30
+ fsspec==2023.12.2
31
+ ftfy==6.3.1
32
+ future==1.0.0
33
+ gitdb==4.0.12
34
+ GitPython==3.1.44
35
+ grpcio==1.73.0
36
+ h5py==3.10.0
37
+ huggingface-hub==0.30.2
38
+ hydra-core==1.3.2
39
+ hypothesis==6.70.0
40
+ imageio==2.37.0
41
+ importlib_metadata==8.5.0
42
+ iniconfig==2.0.0
43
+ ipdb==0.13.13
44
+ jmespath==1.0.1
45
+ joblib==1.3.2
46
+ kiwisolver==1.4.5
47
+ laion_clap==1.1.7
48
+ lazy-dataset==0.0.14
49
+ lazy_loader==0.4
50
+ librosa==0.10.2
51
+ llvmlite==0.42.0
52
+ lxml==6.0.1
53
+ Markdown==3.8
54
+ matplotlib==3.8.2
55
+ mkl-fft==1.3.1
56
+ mkl-service==2.4.0
57
+ mpmath==1.3.0
58
+ msgpack==1.0.8
59
+ networkx==3.0
60
+ numba==0.59.1
61
+ numpy==1.26.4
62
+ nvidia-cublas-cu12==12.4.5.8
63
+ nvidia-cuda-cupti-cu12==12.4.127
64
+ nvidia-cuda-nvrtc-cu12==12.4.127
65
+ nvidia-cuda-runtime-cu12==12.4.127
66
+ nvidia-cudnn-cu12==9.1.0.70
67
+ nvidia-cufft-cu12==11.2.1.3
68
+ nvidia-curand-cu12==10.3.5.147
69
+ nvidia-cusolver-cu12==11.6.1.9
70
+ nvidia-cusparse-cu12==12.3.1.170
71
+ nvidia-cusparselt-cu12==0.6.2
72
+ nvidia-ml-py==12.575.51
73
+ nvidia-nccl-cu12==2.21.5
74
+ nvidia-nvjitlink-cu12==12.4.127
75
+ nvidia-nvtx-cu12==12.4.127
76
+ omegaconf==2.3.0
77
+ packaging==23.2
78
+ pandas==2.2.0
79
+ pathlib==1.0.1
80
+ pillow==11.3.0
81
+ pip-chill==1.0.3
82
+ platformdirs==4.2.1
83
+ pluggy==1.5.0
84
+ pooch==1.8.1
85
+ portalocker==3.2.0
86
+ prettytable==3.16.0
87
+ progressbar==2.5
88
+ protobuf==5.29.2
89
+ psds-eval==0.5.3
90
+ pydantic==2.10.4
91
+ pydantic_core==2.27.2
92
+ pydot-ng==2.0.0
93
+ pyecharts==2.0.8
94
+ pynvml==12.0.0
95
+ pyparsing==3.1.1
96
+ pytest==8.2.0
97
+ python-dateutil==2.8.2
98
+ python-etcd==0.4.5
99
+ python-magic==0.4.27
100
+ regex==2023.12.25
101
+ resampy==0.4.3
102
+ s3transfer==0.13.0
103
+ sacrebleu==2.5.1
104
+ safetensors==0.5.0
105
+ scikit-image==0.25.2
106
+ scikit-learn==1.4.0
107
+ scipy==1.12.0
108
+ sed-eval==0.2.1
109
+ sed-scores-eval==0.0.0
110
+ sentence-transformers==4.1.0
111
+ sentencepiece==0.2.0
112
+ sentry-sdk==2.19.2
113
+ setproctitle==1.3.4
114
+ simplejson==3.20.1
115
+ smmap==5.0.2
116
+ sortedcontainers==2.4.0
117
+ soundfile==0.12.1
118
+ soxr==0.3.7
119
+ swankit==0.2.3
120
+ swanlab==0.6.3
121
+ sympy==1.13.1
122
+ tabulate==0.9.0
123
+ tensorboard==2.19.0
124
+ tensorboard-data-server==0.7.2
125
+ termcolor==3.1.0
126
+ threadpoolctl==3.2.0
127
+ tifffile==2025.5.10
128
+ timm==0.9.12
129
+ tokenizers==0.21.1
130
+ tomli==2.0.1
131
+ torch==2.6.0
132
+ torchaudio==2.6.0
133
+ torchdata==0.10.1
134
+ torchelastic==0.2.2
135
+ torchlibrosa==0.1.0
136
+ torchtext==0.15.0
137
+ torchvision==0.21.0
138
+ transformers==4.51.3
139
+ triton==3.2.0
140
+ types-dataclasses==0.6.6
141
+ typing_extensions==4.12.2
142
+ tzdata==2023.4
143
+ validators==0.28.1
144
+ wandb==0.19.1
145
+ webdataset==1.0.2
146
+ Werkzeug==3.1.3
147
+ wget==3.2
148
+ wrapt==1.17.2
149
+ zipp==3.21.0
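
The pinned list above replaces the original three-line requirements file; gradio itself is no longer pinned here (presumably supplied by the Space runtime). As an optional sanity check, a small snippet like the following (assuming the environment was installed from this file) can confirm the pins the inference stack depends on most:

```python
from importlib.metadata import version, PackageNotFoundError

# spot-check a few of the pins above against the active environment
for pkg, expected in {
    "torch": "2.6.0",
    "torchaudio": "2.6.0",
    "transformers": "4.51.3",
    "diffusers": "0.33.1",
}.items():
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        installed = "missing"
    print(f"{pkg}: {installed} (expected {expected})")
```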
utils/__pycache__/accelerate_utilities.cpython-310.pyc ADDED
Binary file (907 Bytes). View file
 
utils/__pycache__/config.cpython-310.pyc ADDED
Binary file (1.7 kB). View file
 
utils/__pycache__/diffsinger_utilities.cpython-310.pyc ADDED
Binary file (18.6 kB). View file
 
utils/__pycache__/general.cpython-310.pyc ADDED
Binary file (2.18 kB). View file
 
utils/__pycache__/logging.cpython-310.pyc ADDED
Binary file (908 Bytes). View file
 
utils/__pycache__/lr_scheduler_utilities.cpython-310.pyc ADDED
Binary file (5.03 kB). View file