Spaces: Running on Zero

Commit 4aa0f34 · 1 parent: 9577cb2

update to faster inference

Files changed:
- app.py +17 -31
- dia/audio.py +27 -104
- dia/config.py +17 -26
- dia/layers.py +106 -337
- dia/model.py +314 -257
- dia/state.py +234 -0
app.py CHANGED

@@ -1,9 +1,7 @@
-import argparse
 import tempfile
 import time
 from pathlib import Path
 from typing import Optional, Tuple
-import spaces
 
 import gradio as gr
 import numpy as np
@@ -12,40 +10,17 @@ import torch
 
 from dia.model import Dia
 
-# --- Global Setup ---
-parser = argparse.ArgumentParser(description="Gradio interface for Nari TTS")
-parser.add_argument(
-    "--device", type=str, default=None, help="Force device (e.g., 'cuda', 'mps', 'cpu')"
-)
-parser.add_argument("--share", action="store_true", help="Enable Gradio sharing")
-
-args = parser.parse_args()
-
-
-# Determine device
-if args.device:
-    device = torch.device(args.device)
-elif torch.cuda.is_available():
-    device = torch.device("cuda")
-# Simplified MPS check for broader compatibility
-elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-    # Basic check is usually sufficient, detailed check can be problematic
-    device = torch.device("mps")
-else:
-    device = torch.device("cpu")
-
-print(f"Using device: {device}")
 
 # Load Nari model and config
 print("Loading Nari model...")
 try:
     # Use the function from inference.py
-    model = Dia.from_pretrained("nari-labs/Dia-1.6B")
+    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")
 except Exception as e:
     print(f"Error loading Nari model: {e}")
     raise
 
-
+
 def run_inference(
     text_input: str,
     audio_prompt_input: Optional[Tuple[int, np.ndarray]],
@@ -60,7 +35,7 @@ def run_inference(
     Runs Nari inference using the globally loaded model and provided inputs.
    Uses temporary files for text and audio prompt compatibility with inference.generate.
    """
-
+    global model, device  # Access global model, config, device
 
    if not text_input or text_input.isspace():
        raise gr.Error("Text input cannot be empty.")
@@ -146,10 +121,9 @@ def run_inference(
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_p=top_p,
-            use_cfg_filter=True,
            cfg_filter_top_k=cfg_filter_top_k,  # Pass the value here
            use_torch_compile=False,  # Keep False for Gradio stability
-
+            audio_prompt=prompt_path_for_generate,
        )
 
        end_time = time.time()
@@ -192,6 +166,16 @@ def run_inference(
                f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
            )
 
+            # Explicitly convert to int16 to prevent Gradio warning
+            if (
+                output_audio[1].dtype == np.float32
+                or output_audio[1].dtype == np.float64
+            ):
+                audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0)
+                audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
+                output_audio = (output_sr, audio_for_gradio)
+                print("Converted audio to int16 for Gradio output.")
+
        else:
            print("\nGeneration finished, but no valid tokens were produced.")
            # Return default silence
@@ -383,8 +367,10 @@ with gr.Blocks(css=css) as demo:
    else:
        gr.Markdown("_(No examples configured or example prompt file missing)_")
 
-
# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")
+
+    # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
+    # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
    demo.launch()
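
For orientation, here is a minimal sketch of how the updated app drives the model after this commit. The `compute_dtype`, `cfg_scale`, `temperature`, `top_p`, `cfg_filter_top_k`, `use_torch_compile`, and `audio_prompt` arguments are taken from the diff above; the example text, the parameter values, and the assumption that `generate()` takes the text as its first argument and returns a NumPy waveform are illustrative only and not confirmed by this diff.

# Minimal sketch (assumptions noted above); not the full Gradio app.
import numpy as np

from dia.model import Dia

model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")

audio = model.generate(
    "[S1] Hello there. [S2] Hi!",   # assumed example input text
    cfg_scale=3.0,                  # illustrative values, not defaults from the diff
    temperature=1.3,
    top_p=0.95,
    cfg_filter_top_k=35,
    use_torch_compile=False,        # kept False for Gradio stability, per the diff
    audio_prompt=None,              # or a path to a voice-prompt file, per the new kwarg
)

# Gradio expects int16 PCM, so float output is clipped and scaled (mirrors the new app code).
if audio.dtype in (np.float32, np.float64):
    audio = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)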
dia/audio.py CHANGED

@@ -2,10 +2,10 @@ import typing as tp
 
 import torch
 
-from .config import DataConfig
 
-
-
+def build_delay_indices(
+    B: int, T: int, C: int, delay_pattern: tp.List[int]
+) -> tp.Tuple[torch.Tensor, torch.Tensor]:
     """
     Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
     Negative t_idx => BOS; t_idx >= T => PAD.
@@ -69,7 +69,9 @@ def apply_audio_delay(
 
     # Equivalent of tf.gather_nd using advanced indexing
     # Ensure indices are long type if not already (build_delay_indices should handle this)
-    gathered_flat = audio_BxTxC[
+    gathered_flat = audio_BxTxC[
+        indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
+    ]
     gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
 
     # Create masks on the correct device
@@ -82,65 +84,16 @@ def apply_audio_delay(
 
     # If mask_bos, BOS; else if mask_pad, PAD; else original gather
     # All tensors should now be on the same device
-    result_BxTxC = torch.where(
-
+    result_BxTxC = torch.where(
+        mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC)
+    )
 
     return result_BxTxC
 
 
-@torch.no_grad()
-@torch.inference_mode()
-def audio_to_codebook(
-    model,
-    input_values,
-    data_config: DataConfig,
-    padding_mask=None,
-    sample_rate=44100,
-):
-    """
-    Encodes the input audio waveform into discrete codes.
-
-    Args:
-        model: The model to use for encoding.
-        input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
-            Float values of the input audio waveform.
-        padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
-            Padding mask used to pad the `input_values`.
-        sample_rate (`int`, *optional*) :
-            Signal sampling_rate
-
-    Returns:
-        A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
-        factors for each chunk when `normalize` is True. Each frames is a tuple `(codebook, scale)`, with
-        `codebook` of shape `[batch_size, num_codebooks, frames]`.
-        Scale is not used here.
-    """
-    audio_data = model.preprocess(input_values, sample_rate)
-
-    if padding_mask is None:
-        padding_mask = torch.ones_like(input_values).bool()
-
-    _, encoded_frame, _, _, _ = model.encode(audio_data, n_quantizers=None)  # 1, C, T
-    seq_length = encoded_frame.shape[2]
-
-    t_idx_BxTxC, indices_BTCx3 = build_delay_indices(
-        B=1,
-        T=seq_length,
-        C=data_config.channels,
-        delay_pattern=data_config.delay_pattern,
-    )
-
-
-        audio_BxTxC=encoded_frame.transpose(1, 2),  # 1, T, C
-        pad_value=data_config.audio_pad_value,
-        bos_value=data_config.audio_bos_value,
-        precomp=(t_idx_BxTxC, indices_BTCx3),
-    )
-
-    return encoded_frame
 
 
-def build_revert_indices(
+def build_revert_indices(
+    B: int, T: int, C: int, delay_pattern: tp.List[int]
+) -> tp.Tuple[torch.Tensor, torch.Tensor]:
     """
     Precompute indices for the revert operation using PyTorch.
 
@@ -162,8 +115,12 @@ def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) ->
         t_idx_BT1 + delay_arr.view(1, 1, C),
         torch.tensor(T - 1, device=device),
     )
-    b_idx_BxTxC = torch.broadcast_to(
-
+    b_idx_BxTxC = torch.broadcast_to(
+        torch.arange(B, device=device).view(B, 1, 1), [B, T, C]
+    )
+    c_idx_BxTxC = torch.broadcast_to(
+        torch.arange(C, device=device).view(1, 1, C), [B, T, C]
+    )
 
     indices_BTCx3 = torch.stack(
         [
@@ -205,15 +162,21 @@ def revert_audio_delay(
    indices_BTCx3 = indices_BTCx3.to(device)
 
    # Using PyTorch advanced indexing (equivalent to tf.gather_nd or np equivalent)
-    gathered_flat = audio_BxTxC[
-
+    gathered_flat = audio_BxTxC[
+        indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
+    ]
+    gathered_BxTxC = gathered_flat.view(
+        audio_BxTxC.size()
+    )  # Use .size() for robust reshaping
 
    # Create pad_tensor on the correct device
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
    # Create T tensor on the correct device for comparison
    T_tensor = torch.tensor(T, device=device)
 
-    result_BxTxC = torch.where(
+    result_BxTxC = torch.where(
+        t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC
+    )  # Changed np.where to torch.where
 
    return result_BxTxC
 
@@ -238,43 +201,3 @@ def decode(
    except Exception as e:
        print(f"Error in decode method: {str(e)}")
        raise
-
-
-def codebook_to_audio(generated_codes: torch.Tensor, model, delay_pattern, B=1, T=2600, C=9):
-    """Process a single codebook file to generate audio"""
-    # Remove BOS token
-    generated_codes = generated_codes[:, 1:]
-
-    if generated_codes.shape[1] > T:
-        generated_codes = generated_codes[:, :T]
-
-    seq_length = generated_codes.shape[1]
-
-    # Build revert indices
-    t_idx_BxTxC, indices_BTCx3 = build_revert_indices(B=B, T=seq_length, C=C, delay_pattern=delay_pattern)
-
-    # Transpose and add batch dimension
-    audio_BxTxC = generated_codes.transpose(1, 0).unsqueeze(0)
-    reverted_codebook = revert_audio_delay(
-        audio_BxTxC=audio_BxTxC,
-        pad_value=0,
-        precomp=(t_idx_BxTxC, indices_BTCx3),
-        T=seq_length,
-    )
-    reverted_codebook = reverted_codebook[:, :-30, :]
-
-    codebook = reverted_codebook.transpose(1, 2)
-
-    min_valid_index = 0
-    max_valid_index = 1023
-    invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
-
-    num_invalid = torch.sum(invalid_mask).item()
-    if num_invalid > 0:
-        print(f"Warning: Clamping {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0.")
-
-    # Set invalid values to 0 (modify the tensor in-place)
-    codebook[invalid_mask] = 0
-    audio_array = decode(model, codebook)
-
-    return audio_array
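
The delay helpers above all implement the rule stated in the build_delay_indices docstring: out[t, c] = in[t - delay[c], c], with BOS where the source index is negative and PAD where it runs past T. Below is a small self-contained illustration of that rule as a naive loop; it is not the vectorized library code, and the BOS/PAD values are simply the DataConfig defaults from dia/config.py.

import torch

T, C = 5, 3
delay = [0, 1, 2]            # per-channel delays, like DataConfig.delay_pattern
BOS, PAD = 1026, 1025        # audio_bos_value / audio_pad_value defaults

x = torch.arange(T * C).reshape(T, C)   # toy codes, shape (T, C)
out = torch.empty_like(x)
for t in range(T):
    for c in range(C):
        src = t - delay[c]              # out[t, c] = in[t - delay[c], c]
        if src < 0:
            out[t, c] = BOS             # before the sequence start -> BOS
        elif src >= T:
            out[t, c] = PAD             # past the sequence end -> PAD
        else:
            out[t, c] = x[src, c]

print(out)  # channel c is shifted down by delay[c]; revert_audio_delay undoes this shift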
dia/config.py CHANGED

@@ -33,14 +33,20 @@ class DataConfig(BaseModel, frozen=True):
        delay_pattern: List of delay values for each audio channel.
    """
 
-    text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] =
-
+    text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
+        Field(gt=0, multiple_of=128)
+    )
+    audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
+        Field(gt=0, multiple_of=128)
+    )
    channels: int = Field(default=9, gt=0, multiple_of=1)
    text_pad_value: int = Field(default=0)
    audio_eos_value: int = Field(default=1024)
    audio_pad_value: int = Field(default=1025)
    audio_bos_value: int = Field(default=1026)
-    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(
+    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(
+        default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
+    )
 
    def __hash__(self) -> int:
        """Generate a hash based on all fields of the config."""
@@ -67,8 +73,6 @@ class EncoderConfig(BaseModel, frozen=True):
        n_hidden: Hidden dimension size in the MLP layers.
        n_head: Number of attention heads.
        head_dim: Dimension per attention head.
-        mlp_activations: List of activation functions for the MLP layers.
-        use_pre_norm: Whether to use pre-normalization (LayerNorm before attention/MLP).
    """
 
    n_layer: int = Field(gt=0)
@@ -76,8 +80,6 @@ class EncoderConfig(BaseModel, frozen=True):
    n_hidden: int = Field(gt=0)
    n_head: int = Field(gt=0)
    head_dim: int = Field(gt=0)
-    mlp_activations: list[str] = Field(default=["silu", "linear"])
-    use_pre_norm: bool = Field(default=False)
 
 
class DecoderConfig(BaseModel, frozen=True):
@@ -92,8 +94,6 @@ class DecoderConfig(BaseModel, frozen=True):
        gqa_head_dim: Dimension per query head for grouped-query self-attention.
        cross_query_heads: Number of query heads for cross-attention.
        cross_head_dim: Dimension per cross-attention head.
-        mlp_activations: List of activation functions for the MLP layers.
-        use_pre_norm: Whether to use pre-normalization.
    """
 
    n_layer: int = Field(gt=0)
@@ -104,8 +104,6 @@ class DecoderConfig(BaseModel, frozen=True):
    gqa_head_dim: int = Field(gt=0)
    cross_query_heads: int = Field(gt=0)
    cross_head_dim: int = Field(gt=0)
-    mlp_activations: list[str] = Field(default=["silu", "linear"])
-    use_pre_norm: bool = Field(default=False)
 
 
class ModelConfig(BaseModel, frozen=True):
@@ -130,24 +128,16 @@ class ModelConfig(BaseModel, frozen=True):
    dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
    normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
    weight_dtype: str = Field(default="float32", description="Weight precision")
-    rope_min_timescale: int = Field(
-
+    rope_min_timescale: int = Field(
+        default=1, description="Timescale For global Attention"
+    )
+    rope_max_timescale: int = Field(
+        default=10_000, description="Timescale For global Attention"
+    )
 
 
class TrainingConfig(BaseModel, frozen=True):
-
-
-    Note: This configuration currently only includes precision settings.
-    Other training parameters (like batch size, learning rate, optimizer settings)
-    are assumed to be handled externally.
-
-    Attributes:
-        dtype: Data type for activations during training (e.g., "bfloat16", "float32").
-        logits_dot_in_fp32: Whether to compute the final logits dot product in fp32 for stability.
-    """
-
-    dtype: str = Field(default="bfloat16", description="Activation precision")
-    logits_dot_in_fp32: bool = Field(default=False)
+    pass
 
 
class DiaConfig(BaseModel, frozen=True):
@@ -164,6 +154,7 @@ class DiaConfig(BaseModel, frozen=True):
 
    version: str = Field(default="1.0")
    model: ModelConfig
+    # TODO: remove training. this is just for backwards-compatability
    training: TrainingConfig
    data: DataConfig
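
The BeforeValidator used for text_length and audio_length rounds any incoming value up to the next multiple of 128 before validation, which is why multiple_of=128 can always be enforced. A quick check of the expression from the diff:

def round_up_to_128(x: int) -> int:
    # Same expression as the BeforeValidator lambda: (x + 127) // 128 * 128
    return (x + 127) // 128 * 128

assert round_up_to_128(1) == 128
assert round_up_to_128(128) == 128
assert round_up_to_128(129) == 256
assert round_up_to_128(1000) == 1024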
dia/layers.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
from typing import Any
|
| 2 |
-
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import torch.nn.functional as F
|
|
@@ -7,26 +5,13 @@ from torch import Tensor
|
|
| 7 |
from torch.nn import RMSNorm
|
| 8 |
|
| 9 |
from .config import DiaConfig
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
|
| 13 |
return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
|
| 14 |
|
| 15 |
|
| 16 |
-
def _str_to_dtype(dtype_str: str) -> torch.dtype | None:
|
| 17 |
-
# Allow None for default behavior
|
| 18 |
-
if dtype_str is None or dtype_str.lower() == "none":
|
| 19 |
-
return None
|
| 20 |
-
if dtype_str == "float32":
|
| 21 |
-
return torch.float32
|
| 22 |
-
elif dtype_str == "float16":
|
| 23 |
-
return torch.float16
|
| 24 |
-
elif dtype_str == "bfloat16":
|
| 25 |
-
return torch.bfloat16
|
| 26 |
-
else:
|
| 27 |
-
raise ValueError(f"Unsupported dtype string: {dtype_str}")
|
| 28 |
-
|
| 29 |
-
|
| 30 |
class DenseGeneral(nn.Module):
|
| 31 |
"""
|
| 32 |
PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
|
|
@@ -50,7 +35,6 @@ class DenseGeneral(nn.Module):
|
|
| 50 |
in_shapes: tuple[int, ...],
|
| 51 |
out_features: tuple[int, ...],
|
| 52 |
axis: tuple[int, ...] = (-1,),
|
| 53 |
-
dtype: torch.dtype | None = None,
|
| 54 |
weight_dtype: torch.dtype | None = None,
|
| 55 |
device: torch.device | None = None,
|
| 56 |
):
|
|
@@ -58,7 +42,6 @@ class DenseGeneral(nn.Module):
|
|
| 58 |
self.in_shapes = in_shapes
|
| 59 |
self.out_features = out_features
|
| 60 |
self.axis = axis
|
| 61 |
-
self.dtype = dtype
|
| 62 |
self.kernel_shape = self.in_shapes + self.out_features
|
| 63 |
|
| 64 |
factory_kwargs = {"device": device, "dtype": weight_dtype}
|
|
@@ -70,95 +53,44 @@ class DenseGeneral(nn.Module):
|
|
| 70 |
kernel_contract_axes = tuple(range(len(norm_axis)))
|
| 71 |
|
| 72 |
output = torch.tensordot(
|
| 73 |
-
inputs.
|
| 74 |
-
self.weight
|
| 75 |
dims=(norm_axis, kernel_contract_axes),
|
| 76 |
).to(inputs.dtype)
|
| 77 |
return output
|
| 78 |
|
| 79 |
|
| 80 |
-
def get_activation_fn(activation_string: str) -> nn.Module: # Return Module instance
|
| 81 |
-
"""Maps activation string to PyTorch activation function module."""
|
| 82 |
-
if activation_string == "gelu":
|
| 83 |
-
return nn.GELU()
|
| 84 |
-
elif activation_string == "relu":
|
| 85 |
-
return nn.ReLU()
|
| 86 |
-
elif activation_string == "silu" or activation_string == "swish":
|
| 87 |
-
return nn.SiLU()
|
| 88 |
-
elif activation_string == "linear":
|
| 89 |
-
return nn.Identity()
|
| 90 |
-
else:
|
| 91 |
-
raise ValueError(f"Unsupported activation function: {activation_string}")
|
| 92 |
-
|
| 93 |
-
|
| 94 |
class MlpBlock(nn.Module):
|
| 95 |
"""MLP block using DenseGeneral."""
|
| 96 |
|
| 97 |
def __init__(
|
| 98 |
-
self,
|
| 99 |
-
config: DiaConfig,
|
| 100 |
-
embed_dim: int,
|
| 101 |
-
intermediate_dim: int,
|
| 102 |
-
dropout_rate: float,
|
| 103 |
-
activations: list[str] = ["silu", "linear"],
|
| 104 |
-
use_pre_norm: bool = False,
|
| 105 |
):
|
| 106 |
super().__init__()
|
| 107 |
-
self.use_pre_norm = use_pre_norm
|
| 108 |
-
num_activations = len(activations)
|
| 109 |
-
compute_dtype = _str_to_dtype(config.training.dtype)
|
| 110 |
-
weight_dtype = _str_to_dtype(config.model.weight_dtype)
|
| 111 |
self.dtype = compute_dtype
|
| 112 |
-
# Assume default device for now, could be passed in config
|
| 113 |
-
|
| 114 |
-
if use_pre_norm:
|
| 115 |
-
self.pre_norm = RMSNorm(
|
| 116 |
-
embed_dim,
|
| 117 |
-
eps=config.model.normalization_layer_epsilon,
|
| 118 |
-
dtype=torch.float32,
|
| 119 |
-
)
|
| 120 |
|
| 121 |
self.wi_fused = DenseGeneral(
|
| 122 |
in_shapes=(embed_dim,),
|
| 123 |
-
out_features=(
|
| 124 |
-
num_activations,
|
| 125 |
-
intermediate_dim,
|
| 126 |
-
),
|
| 127 |
axis=(-1,),
|
| 128 |
-
|
| 129 |
-
weight_dtype=weight_dtype,
|
| 130 |
)
|
| 131 |
|
| 132 |
-
self.activation_fn_0 = get_activation_fn(activations[0]) # silu
|
| 133 |
-
self.activation_fn_1 = get_activation_fn(activations[1]) # linear
|
| 134 |
-
|
| 135 |
-
self.dropout = nn.Dropout(dropout_rate)
|
| 136 |
-
|
| 137 |
-
# Output layer using DenseGeneral
|
| 138 |
self.wo = DenseGeneral(
|
| 139 |
in_shapes=(intermediate_dim,),
|
| 140 |
out_features=(embed_dim,),
|
| 141 |
axis=(-1,),
|
| 142 |
-
|
| 143 |
-
weight_dtype=weight_dtype,
|
| 144 |
)
|
| 145 |
|
| 146 |
-
def forward(self, x: torch.Tensor
|
| 147 |
"""Forward pass."""
|
| 148 |
-
if self.use_pre_norm and hasattr(self, "pre_norm"):
|
| 149 |
-
x = self.pre_norm(x)
|
| 150 |
-
|
| 151 |
fused_x = self.wi_fused(x)
|
| 152 |
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
gate = self.activation_fn_0(gate_input)
|
| 157 |
-
up = self.activation_fn_1(up_input)
|
| 158 |
-
hidden = torch.mul(gate, up).to(self.dtype)
|
| 159 |
|
| 160 |
-
|
| 161 |
-
hidden = self.dropout(hidden)
|
| 162 |
|
| 163 |
output = self.wo(hidden)
|
| 164 |
return output
|
|
@@ -207,37 +139,6 @@ class RotaryEmbedding(nn.Module):
|
|
| 207 |
return torch.cat((first_part, second_part), dim=-1)
|
| 208 |
|
| 209 |
|
| 210 |
-
class KVCache:
|
| 211 |
-
def __init__(self, num_heads, max_len, head_dim, device, k=None, v=None):
|
| 212 |
-
self.k = torch.zeros((2, num_heads, max_len, head_dim), device=device) if k is None else k
|
| 213 |
-
self.v = torch.zeros((2, num_heads, max_len, head_dim), device=device) if v is None else v
|
| 214 |
-
self.current_idx = 0
|
| 215 |
-
self.max_len = max_len
|
| 216 |
-
|
| 217 |
-
def get_kv_for_attention(self, current_k, current_v):
|
| 218 |
-
if self.current_idx == 0:
|
| 219 |
-
return current_k, current_v
|
| 220 |
-
else:
|
| 221 |
-
past_k = self.k[:, :, : self.current_idx, :]
|
| 222 |
-
past_v = self.v[:, :, : self.current_idx, :]
|
| 223 |
-
attn_k = torch.cat((past_k, current_k), dim=2)
|
| 224 |
-
attn_v = torch.cat((past_v, current_v), dim=2)
|
| 225 |
-
return attn_k, attn_v
|
| 226 |
-
|
| 227 |
-
def update_cache(self, k, v):
|
| 228 |
-
assert self.current_idx < self.max_len
|
| 229 |
-
self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
|
| 230 |
-
self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
|
| 231 |
-
self.current_idx += 1
|
| 232 |
-
|
| 233 |
-
def prefill_kv(self, k, v):
|
| 234 |
-
prefill_len = k.shape[2]
|
| 235 |
-
assert prefill_len <= self.max_len
|
| 236 |
-
self.k[:, :, :prefill_len, :] = k
|
| 237 |
-
self.v[:, :, :prefill_len, :] = v
|
| 238 |
-
self.current_idx = prefill_len
|
| 239 |
-
|
| 240 |
-
|
| 241 |
class Attention(nn.Module):
|
| 242 |
"""Attention using DenseGeneral."""
|
| 243 |
|
|
@@ -249,7 +150,7 @@ class Attention(nn.Module):
|
|
| 249 |
num_query_heads: int,
|
| 250 |
num_kv_heads: int,
|
| 251 |
head_dim: int,
|
| 252 |
-
|
| 253 |
is_cross_attn: bool = False,
|
| 254 |
out_embed_dim: int | None = None,
|
| 255 |
):
|
|
@@ -258,13 +159,12 @@ class Attention(nn.Module):
|
|
| 258 |
self.num_kv_heads = num_kv_heads
|
| 259 |
self.head_dim = head_dim
|
| 260 |
self.is_cross_attn = is_cross_attn
|
| 261 |
-
self.dropout_rate = dropout_rate
|
| 262 |
-
compute_dtype = _str_to_dtype(config.training.dtype)
|
| 263 |
-
weight_dtype = _str_to_dtype(config.model.weight_dtype)
|
| 264 |
self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
|
| 265 |
self.projected_query_dim = num_query_heads * head_dim
|
| 266 |
if num_query_heads % num_kv_heads != 0:
|
| 267 |
-
raise ValueError(
|
|
|
|
|
|
|
| 268 |
self.num_gqa_groups = num_query_heads // num_kv_heads
|
| 269 |
|
| 270 |
# --- Projection Layers using DenseGeneral ---
|
|
@@ -272,29 +172,25 @@ class Attention(nn.Module):
|
|
| 272 |
in_shapes=(q_embed_dim,),
|
| 273 |
out_features=(num_query_heads, head_dim),
|
| 274 |
axis=(-1,),
|
| 275 |
-
|
| 276 |
-
weight_dtype=weight_dtype,
|
| 277 |
)
|
| 278 |
self.k_proj = DenseGeneral(
|
| 279 |
in_shapes=(kv_embed_dim,),
|
| 280 |
out_features=(num_kv_heads, head_dim),
|
| 281 |
axis=(-1,),
|
| 282 |
-
|
| 283 |
-
weight_dtype=weight_dtype,
|
| 284 |
)
|
| 285 |
self.v_proj = DenseGeneral(
|
| 286 |
in_shapes=(kv_embed_dim,),
|
| 287 |
out_features=(num_kv_heads, head_dim),
|
| 288 |
axis=(-1,),
|
| 289 |
-
|
| 290 |
-
weight_dtype=weight_dtype,
|
| 291 |
)
|
| 292 |
self.o_proj = DenseGeneral(
|
| 293 |
in_shapes=(num_query_heads, head_dim),
|
| 294 |
out_features=(self.output_dim,),
|
| 295 |
axis=(-2, -1),
|
| 296 |
-
|
| 297 |
-
weight_dtype=weight_dtype,
|
| 298 |
)
|
| 299 |
|
| 300 |
# --- Rotary Embedding ---
|
|
@@ -311,10 +207,11 @@ class Attention(nn.Module):
|
|
| 311 |
Xkv: torch.Tensor, # (B, S, E) S = 1 in AR generation
|
| 312 |
q_positions: torch.Tensor, # (B, T)
|
| 313 |
kv_positions: torch.Tensor | None = None, # (B, S)
|
| 314 |
-
|
| 315 |
-
|
| 316 |
cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
|
| 317 |
-
prefill: bool = False,
|
|
|
|
| 318 |
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
|
| 319 |
"""
|
| 320 |
Performs attention calculation with optional KV caching.
|
|
@@ -324,7 +221,6 @@ class Attention(nn.Module):
|
|
| 324 |
Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
|
| 325 |
q_positions: Positions for queries (B, T).
|
| 326 |
kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
|
| 327 |
-
deterministic: If True, disable dropout.
|
| 328 |
attn_mask: Attention mask.
|
| 329 |
cache: KVCache.
|
| 330 |
prefill: If True, use prefill mode.
|
|
@@ -342,72 +238,51 @@ class Attention(nn.Module):
|
|
| 342 |
Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
|
| 343 |
Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
|
| 344 |
|
| 345 |
-
# Input values into attention calculation
|
| 346 |
attn_k: torch.Tensor | None = None
|
| 347 |
attn_v: torch.Tensor | None = None
|
| 348 |
-
new_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None
|
| 349 |
|
| 350 |
-
# Decoder Cross Attention
|
| 351 |
if self.is_cross_attn:
|
| 352 |
-
# Directly use cache (no need to check index)
|
| 353 |
attn_k, attn_v = cache.k, cache.v
|
| 354 |
-
if attn_k.shape[1] != self.num_query_heads or attn_v.shape[1] != self.num_query_heads:
|
| 355 |
-
raise ValueError(
|
| 356 |
-
f"Cross-attention cache head dimension ({attn_k.shape[1]}) "
|
| 357 |
-
f"does not match num_query_heads ({self.num_query_heads}). "
|
| 358 |
-
"Cache should be pre-repeated for GQA."
|
| 359 |
-
)
|
| 360 |
-
# Self Attention
|
| 361 |
else:
|
| 362 |
Xk_BxSxKxH = self.k_proj(Xkv) # (B, S, K, H)
|
| 363 |
Xv_BxSxKxH = self.v_proj(Xkv) # (B, S, K, H)
|
| 364 |
-
Xk_BxSxKxH = self.rotary_emb(
|
|
|
|
|
|
|
| 365 |
|
| 366 |
Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
|
| 367 |
Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)
|
| 368 |
-
# S=1 for Decode Step
|
| 369 |
-
|
| 370 |
-
if self.num_gqa_groups > 1:
|
| 371 |
-
Xk_BxNxSxH = Xk_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
|
| 372 |
-
Xv_BxNxSxH = Xv_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
|
| 373 |
-
else:
|
| 374 |
-
Xk_BxNxSxH = Xk_BxKxSxH
|
| 375 |
-
Xv_BxNxSxH = Xv_BxKxSxH
|
| 376 |
|
| 377 |
-
# Encoder Self Attention
|
| 378 |
if cache is None:
|
| 379 |
-
attn_k =
|
| 380 |
-
attn_v =
|
| 381 |
-
# Decoder Self Attention
|
| 382 |
else:
|
| 383 |
-
# In prefill mode, we fill in cache until prefill length
|
| 384 |
if prefill:
|
| 385 |
-
attn_k, attn_v =
|
| 386 |
-
cache.
|
| 387 |
-
# In decode step, we add current K/V to cache step by step
|
| 388 |
else:
|
| 389 |
-
|
| 390 |
-
attn_k, attn_v = cache.get_kv_for_attention(Xk_BxNxSxH, Xv_BxNxSxH)
|
| 391 |
|
| 392 |
attn_output = F.scaled_dot_product_attention(
|
| 393 |
Xq_BxNxTxH,
|
| 394 |
attn_k,
|
| 395 |
attn_v,
|
| 396 |
attn_mask=attn_mask,
|
| 397 |
-
dropout_p=self.dropout_rate if not deterministic else 0.0,
|
| 398 |
scale=1.0,
|
|
|
|
|
|
|
| 399 |
)
|
| 400 |
|
| 401 |
attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
|
| 402 |
output = self.o_proj(attn_output)
|
| 403 |
|
| 404 |
-
return output.to(original_dtype)
|
| 405 |
|
| 406 |
|
| 407 |
class EncoderLayer(nn.Module):
|
| 408 |
"""Transformer Encoder Layer using DenseGeneral."""
|
| 409 |
|
| 410 |
-
def __init__(self, config: DiaConfig):
|
| 411 |
super().__init__()
|
| 412 |
self.config = config
|
| 413 |
model_config = config.model
|
|
@@ -420,13 +295,13 @@ class EncoderLayer(nn.Module):
|
|
| 420 |
dtype=torch.float32,
|
| 421 |
)
|
| 422 |
self.self_attention = Attention(
|
| 423 |
-
config
|
| 424 |
q_embed_dim=embed_dim,
|
| 425 |
kv_embed_dim=embed_dim,
|
| 426 |
num_query_heads=enc_config.n_head,
|
| 427 |
num_kv_heads=enc_config.n_head,
|
| 428 |
head_dim=enc_config.head_dim,
|
| 429 |
-
|
| 430 |
is_cross_attn=False,
|
| 431 |
out_embed_dim=embed_dim,
|
| 432 |
)
|
|
@@ -436,62 +311,52 @@ class EncoderLayer(nn.Module):
|
|
| 436 |
dtype=torch.float32,
|
| 437 |
)
|
| 438 |
self.mlp = MlpBlock(
|
| 439 |
-
config=config,
|
| 440 |
embed_dim=embed_dim,
|
| 441 |
intermediate_dim=enc_config.n_hidden,
|
| 442 |
-
|
| 443 |
-
dropout_rate=model_config.dropout,
|
| 444 |
-
use_pre_norm=enc_config.use_pre_norm,
|
| 445 |
)
|
| 446 |
-
self.dropout = nn.Dropout(model_config.dropout)
|
| 447 |
|
| 448 |
def forward(
|
| 449 |
self,
|
| 450 |
x: torch.Tensor,
|
| 451 |
-
|
| 452 |
-
deterministic: bool = True,
|
| 453 |
-
attn_mask: torch.Tensor | None = None,
|
| 454 |
) -> torch.Tensor:
|
| 455 |
residual = x
|
| 456 |
x_norm = self.pre_sa_norm(x)
|
| 457 |
-
|
| 458 |
-
sa_out, _ = self.self_attention(
|
| 459 |
Xq=x_norm,
|
| 460 |
Xkv=x_norm,
|
| 461 |
-
q_positions=
|
| 462 |
-
kv_positions=
|
| 463 |
-
|
| 464 |
-
attn_mask=attn_mask,
|
| 465 |
)
|
| 466 |
x = residual + sa_out
|
| 467 |
|
| 468 |
residual = x
|
| 469 |
x_norm = self.post_sa_norm(x)
|
| 470 |
-
mlp_out = self.mlp(x_norm
|
| 471 |
x = residual + mlp_out
|
| 472 |
|
| 473 |
-
if not deterministic:
|
| 474 |
-
x = self.dropout(x)
|
| 475 |
return x
|
| 476 |
|
| 477 |
|
| 478 |
class Encoder(nn.Module):
|
| 479 |
"""Transformer Encoder Stack using DenseGeneral."""
|
| 480 |
|
| 481 |
-
def __init__(self, config: DiaConfig):
|
| 482 |
super().__init__()
|
| 483 |
self.config = config
|
| 484 |
model_config = config.model
|
| 485 |
enc_config = config.model.encoder
|
| 486 |
-
compute_dtype = _str_to_dtype(config.training.dtype)
|
| 487 |
|
| 488 |
self.embedding = nn.Embedding(
|
| 489 |
model_config.src_vocab_size,
|
| 490 |
enc_config.n_embd,
|
| 491 |
dtype=compute_dtype,
|
| 492 |
)
|
| 493 |
-
self.
|
| 494 |
-
|
|
|
|
| 495 |
self.norm = RMSNorm(
|
| 496 |
enc_config.n_embd,
|
| 497 |
eps=model_config.normalization_layer_epsilon,
|
|
@@ -501,32 +366,21 @@ class Encoder(nn.Module):
|
|
| 501 |
def forward(
|
| 502 |
self,
|
| 503 |
x_ids: torch.Tensor,
|
| 504 |
-
|
| 505 |
-
deterministic: bool = True,
|
| 506 |
-
attn_mask: torch.Tensor | None = None,
|
| 507 |
) -> torch.Tensor:
|
| 508 |
x = self.embedding(x_ids)
|
| 509 |
|
| 510 |
-
if not deterministic:
|
| 511 |
-
x = self.dropout(x)
|
| 512 |
-
|
| 513 |
for layer in self.layers:
|
| 514 |
-
x = layer(
|
| 515 |
-
|
| 516 |
-
src_positions=src_positions,
|
| 517 |
-
deterministic=deterministic,
|
| 518 |
-
attn_mask=attn_mask,
|
| 519 |
-
)
|
| 520 |
x = self.norm(x)
|
| 521 |
-
if not deterministic:
|
| 522 |
-
x = self.dropout(x)
|
| 523 |
return x
|
| 524 |
|
| 525 |
|
| 526 |
class DecoderLayer(nn.Module):
|
| 527 |
"""Transformer Decoder Layer using DenseGeneral."""
|
| 528 |
|
| 529 |
-
def __init__(self, config: DiaConfig):
|
| 530 |
super().__init__()
|
| 531 |
self.config = config
|
| 532 |
model_config = config.model
|
|
@@ -554,13 +408,13 @@ class DecoderLayer(nn.Module):
|
|
| 554 |
|
| 555 |
# Self-Attention (GQA) with Causal Masking
|
| 556 |
self.self_attention = Attention(
|
| 557 |
-
config
|
| 558 |
q_embed_dim=dec_embed_dim,
|
| 559 |
kv_embed_dim=dec_embed_dim,
|
| 560 |
num_query_heads=dec_config.gqa_query_heads,
|
| 561 |
num_kv_heads=dec_config.kv_heads,
|
| 562 |
head_dim=dec_config.gqa_head_dim,
|
| 563 |
-
|
| 564 |
is_cross_attn=False,
|
| 565 |
out_embed_dim=dec_embed_dim,
|
| 566 |
)
|
|
@@ -572,116 +426,105 @@ class DecoderLayer(nn.Module):
|
|
| 572 |
num_query_heads=dec_config.cross_query_heads,
|
| 573 |
num_kv_heads=dec_config.cross_query_heads,
|
| 574 |
head_dim=dec_config.cross_head_dim,
|
| 575 |
-
|
| 576 |
is_cross_attn=True,
|
| 577 |
out_embed_dim=dec_embed_dim,
|
| 578 |
)
|
| 579 |
# MLP
|
| 580 |
self.mlp = MlpBlock(
|
| 581 |
-
config=config,
|
| 582 |
embed_dim=dec_embed_dim,
|
| 583 |
intermediate_dim=dec_config.n_hidden,
|
| 584 |
-
|
| 585 |
-
dropout_rate=model_config.dropout,
|
| 586 |
-
use_pre_norm=dec_config.use_pre_norm,
|
| 587 |
)
|
| 588 |
|
| 589 |
def forward(
|
| 590 |
self,
|
| 591 |
x: torch.Tensor,
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
deterministic: bool,
|
| 596 |
-
self_attn_mask: torch.Tensor,
|
| 597 |
-
cross_attn_mask: torch.Tensor,
|
| 598 |
-
self_attn_cache: KVCache,
|
| 599 |
-
cross_attn_cache: KVCache,
|
| 600 |
prefill: bool = False,
|
| 601 |
) -> torch.Tensor:
|
| 602 |
residual = x
|
| 603 |
x_norm = self.pre_sa_norm(x)
|
| 604 |
|
| 605 |
-
sa_out
|
| 606 |
Xq=x_norm, # (2, 1, D)
|
| 607 |
Xkv=x_norm, # (2, 1, D)
|
| 608 |
-
q_positions=
|
| 609 |
-
kv_positions=
|
| 610 |
-
|
| 611 |
-
attn_mask=self_attn_mask, # (2, 1, 1, S_max)
|
| 612 |
cache=self_attn_cache,
|
| 613 |
prefill=prefill,
|
|
|
|
| 614 |
)
|
| 615 |
|
| 616 |
x = residual + sa_out
|
| 617 |
|
| 618 |
-
# 2. Cross-Attention
|
| 619 |
residual = x
|
| 620 |
x_norm = self.pre_ca_norm(x)
|
| 621 |
-
ca_out
|
| 622 |
Xq=x_norm,
|
| 623 |
-
Xkv=
|
| 624 |
-
q_positions=
|
| 625 |
-
kv_positions=
|
| 626 |
-
|
| 627 |
-
attn_mask=cross_attn_mask,
|
| 628 |
cache=cross_attn_cache,
|
| 629 |
)
|
| 630 |
x = residual + ca_out
|
| 631 |
|
| 632 |
-
# 3. MLP
|
| 633 |
residual = x
|
| 634 |
x_norm = self.pre_mlp_norm(x)
|
| 635 |
-
mlp_out = self.mlp(x_norm
|
| 636 |
x = residual + mlp_out
|
| 637 |
|
| 638 |
-
return x
|
| 639 |
|
| 640 |
|
| 641 |
class Decoder(nn.Module):
|
| 642 |
"""Transformer Decoder Stack using DenseGeneral."""
|
| 643 |
|
| 644 |
-
def __init__(self, config: DiaConfig):
|
| 645 |
super().__init__()
|
| 646 |
self.config = config
|
| 647 |
model_config = config.model
|
| 648 |
dec_config = config.model.decoder
|
| 649 |
-
train_config = config.training
|
| 650 |
data_config = config.data
|
| 651 |
-
compute_dtype = _str_to_dtype(config.training.dtype)
|
| 652 |
-
weight_dtype = _str_to_dtype(config.model.weight_dtype)
|
| 653 |
self.num_channels = data_config.channels
|
| 654 |
self.num_layers = dec_config.n_layer
|
| 655 |
|
| 656 |
self.embeddings = nn.ModuleList(
|
| 657 |
[
|
| 658 |
-
nn.Embedding(
|
|
|
|
|
|
|
| 659 |
for _ in range(self.num_channels)
|
| 660 |
]
|
| 661 |
)
|
| 662 |
-
self.
|
| 663 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 664 |
self.norm = RMSNorm(
|
| 665 |
dec_config.n_embd,
|
| 666 |
eps=model_config.normalization_layer_epsilon,
|
| 667 |
dtype=torch.float32,
|
| 668 |
)
|
| 669 |
|
| 670 |
-
# Final Logits Projection using DenseGeneral
|
| 671 |
self.logits_dense = DenseGeneral(
|
| 672 |
in_shapes=(dec_config.n_embd,),
|
| 673 |
out_features=(self.num_channels, model_config.tgt_vocab_size),
|
| 674 |
axis=(-1,),
|
| 675 |
-
|
| 676 |
-
weight_dtype=weight_dtype,
|
| 677 |
)
|
| 678 |
-
self.logits_in_fp32 = train_config.logits_dot_in_fp32
|
| 679 |
|
| 680 |
-
def
|
| 681 |
self,
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
src_positions: torch.Tensor | None, # (B, S)
|
| 685 |
) -> list[KVCache]:
|
| 686 |
"""
|
| 687 |
Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
|
|
@@ -690,35 +533,21 @@ class Decoder(nn.Module):
|
|
| 690 |
|
| 691 |
for layer in self.layers:
|
| 692 |
cross_attn_module = layer.cross_attention
|
| 693 |
-
k_proj = cross_attn_module.k_proj(
|
| 694 |
-
v_proj = cross_attn_module.v_proj(
|
| 695 |
|
| 696 |
-
k_proj = cross_attn_module.rotary_emb(k_proj, position=
|
| 697 |
k = k_proj.transpose(1, 2)
|
| 698 |
v = v_proj.transpose(1, 2)
|
| 699 |
|
| 700 |
-
per_layer_kv_cache.append(
|
| 701 |
-
KVCache(
|
| 702 |
-
cross_attn_module.num_kv_heads,
|
| 703 |
-
max_len,
|
| 704 |
-
cross_attn_module.head_dim,
|
| 705 |
-
k.device,
|
| 706 |
-
k=k,
|
| 707 |
-
v=v,
|
| 708 |
-
)
|
| 709 |
-
)
|
| 710 |
|
| 711 |
return per_layer_kv_cache
|
| 712 |
|
| 713 |
def decode_step(
|
| 714 |
self,
|
| 715 |
tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C]
|
| 716 |
-
|
| 717 |
-
encoder_out: torch.Tensor, # [B, S, E]
|
| 718 |
-
self_attn_mask: Any, # None
|
| 719 |
-
cross_attn_mask: torch.Tensor, # [B, 1, 1, S]
|
| 720 |
-
self_attention_cache: list[KVCache],
|
| 721 |
-
cross_attention_cache: list[KVCache],
|
| 722 |
) -> torch.Tensor:
|
| 723 |
"""
|
| 724 |
Performs a single decoding step, managing KV caches layer by layer.
|
|
@@ -727,7 +556,6 @@ class Decoder(nn.Module):
|
|
| 727 |
A tuple containing:
|
| 728 |
- logits_Bx1xCV: The final output logits for the current step (B, 1, C*V), cast to float32.
|
| 729 |
"""
|
| 730 |
-
assert self_attn_mask is None, "Self-attention mask should be None, kept for pattern"
|
| 731 |
|
| 732 |
x = None
|
| 733 |
for i in range(self.num_channels):
|
|
@@ -735,40 +563,23 @@ class Decoder(nn.Module):
|
|
| 735 |
channel_embed = self.embeddings[i](channel_tokens)
|
| 736 |
x = channel_embed if x is None else x + channel_embed
|
| 737 |
|
| 738 |
-
new_cache = []
|
| 739 |
-
|
| 740 |
for i, layer in enumerate(self.layers):
|
| 741 |
-
self_cache =
|
| 742 |
-
cross_cache =
|
| 743 |
-
x
|
| 744 |
x, # (2, 1, D)
|
| 745 |
-
|
| 746 |
-
src_positions=None, # CA KV is already computed
|
| 747 |
-
tgt_positions=tgt_pos_Bx1, # (2, 1)
|
| 748 |
-
deterministic=True,
|
| 749 |
-
self_attn_mask=None,
|
| 750 |
-
cross_attn_mask=cross_attn_mask,
|
| 751 |
self_attn_cache=self_cache,
|
| 752 |
cross_attn_cache=cross_cache,
|
| 753 |
)
|
| 754 |
-
new_cache.append(new_kv_cache)
|
| 755 |
|
| 756 |
x = self.norm(x)
|
| 757 |
logits_Bx1xCxV = self.logits_dense(x)
|
| 758 |
|
| 759 |
-
return logits_Bx1xCxV.to(torch.float32)
|
| 760 |
|
| 761 |
def forward(
|
| 762 |
-
self,
|
| 763 |
-
tgt_ids_BxTxC: torch.Tensor,
|
| 764 |
-
encoder_out: torch.Tensor,
|
| 765 |
-
tgt_positions: torch.Tensor,
|
| 766 |
-
src_positions: torch.Tensor,
|
| 767 |
-
deterministic: bool,
|
| 768 |
-
self_attn_mask: torch.Tensor,
|
| 769 |
-
cross_attn_mask: torch.Tensor,
|
| 770 |
-
self_attention_cache: list[KVCache],
|
| 771 |
-
cross_attention_cache: list[KVCache],
|
| 772 |
) -> torch.Tensor:
|
| 773 |
"""
|
| 774 |
Forward pass for the Decoder stack, managing KV caches.
|
|
@@ -778,7 +589,6 @@ class Decoder(nn.Module):
|
|
| 778 |
encoder_out: Output from the encoder (B, S, E).
|
| 779 |
tgt_positions: Positions for target sequence (B, T).
|
| 780 |
src_positions: Positions for source sequence (B, S).
|
| 781 |
-
deterministic: Disable dropout if True.
|
| 782 |
self_attn_mask: Mask for self-attention.
|
| 783 |
cross_attn_mask: Mask for cross-attention.
|
| 784 |
past_key_values: List containing the self-attention KV cache for each layer
|
|
@@ -804,20 +614,14 @@ class Decoder(nn.Module):
|
|
| 804 |
channel_embed = self.embeddings[i](channel_tokens)
|
| 805 |
x = channel_embed if x is None else x + channel_embed
|
| 806 |
|
| 807 |
-
if not deterministic:
|
| 808 |
-
x = self.dropout(x)
|
| 809 |
-
|
| 810 |
for i, layer in enumerate(self.layers):
|
| 811 |
-
|
|
|
|
|
|
|
| 812 |
x,
|
| 813 |
-
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
deterministic=deterministic,
|
| 817 |
-
self_attn_mask=self_attn_mask,
|
| 818 |
-
cross_attn_mask=cross_attn_mask,
|
| 819 |
-
self_attn_cache=self_attention_cache[i],
|
| 820 |
-
cross_attn_cache=cross_attention_cache[i],
|
| 821 |
prefill=True,
|
| 822 |
)
|
| 823 |
|
|
@@ -831,43 +635,8 @@ class Decoder(nn.Module):
|
|
| 831 |
class DiaModel(nn.Module):
|
| 832 |
"""PyTorch Dia Model using DenseGeneral."""
|
| 833 |
|
| 834 |
-
def __init__(self, config: DiaConfig):
|
| 835 |
super().__init__()
|
| 836 |
self.config = config
|
| 837 |
-
self.encoder = Encoder(config)
|
| 838 |
-
self.decoder = Decoder(config)
|
| 839 |
-
|
| 840 |
-
def forward(
|
| 841 |
-
self,
|
| 842 |
-
src_BxS: torch.Tensor,
|
| 843 |
-
tgt_BxTxC: torch.Tensor,
|
| 844 |
-
src_positions: torch.Tensor | None = None,
|
| 845 |
-
tgt_positions: torch.Tensor | None = None,
|
| 846 |
-
enc_self_attn_mask: torch.Tensor | None = None,
|
| 847 |
-
dec_self_attn_mask: torch.Tensor | None = None,
|
| 848 |
-
dec_cross_attn_mask: torch.Tensor | None = None,
|
| 849 |
-
enable_dropout: bool = True,
|
| 850 |
-
):
|
| 851 |
-
deterministic = not enable_dropout
|
| 852 |
-
|
| 853 |
-
# --- Encoder Pass ---
|
| 854 |
-
encoder_out = self.encoder(
|
| 855 |
-
x_ids=src_BxS,
|
| 856 |
-
src_positions=src_positions,
|
| 857 |
-
deterministic=deterministic,
|
| 858 |
-
attn_mask=enc_self_attn_mask,
|
| 859 |
-
)
|
| 860 |
-
|
| 861 |
-
# --- Decoder Pass ---
|
| 862 |
-
logits, _ = self.decoder(
|
| 863 |
-
tgt_ids_BxTxC=tgt_BxTxC,
|
| 864 |
-
encoder_out=encoder_out,
|
| 865 |
-
tgt_positions=tgt_positions,
|
| 866 |
-
src_positions=src_positions,
|
| 867 |
-
deterministic=deterministic,
|
| 868 |
-
self_attn_mask=dec_self_attn_mask,
|
| 869 |
-
cross_attn_mask=dec_cross_attn_mask,
|
| 870 |
-
precomputed_cross_attn_kv=None,
|
| 871 |
-
)
|
| 872 |
-
|
| 873 |
-
return logits
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
|
|
|
| 5 |
from torch.nn import RMSNorm
|
| 6 |
|
| 7 |
from .config import DiaConfig
|
| 8 |
+
from .state import DecoderInferenceState, EncoderInferenceState, KVCache
|
| 9 |
|
| 10 |
|
| 11 |
def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
|
| 12 |
return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
|
| 13 |
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
class DenseGeneral(nn.Module):
|
| 16 |
"""
|
| 17 |
PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
|
|
|
|
| 35 |
in_shapes: tuple[int, ...],
|
| 36 |
out_features: tuple[int, ...],
|
| 37 |
axis: tuple[int, ...] = (-1,),
|
|
|
|
| 38 |
weight_dtype: torch.dtype | None = None,
|
| 39 |
device: torch.device | None = None,
|
| 40 |
):
|
|
|
|
| 42 |
self.in_shapes = in_shapes
|
| 43 |
self.out_features = out_features
|
| 44 |
self.axis = axis
|
|
|
|
| 45 |
self.kernel_shape = self.in_shapes + self.out_features
|
| 46 |
|
| 47 |
factory_kwargs = {"device": device, "dtype": weight_dtype}
|
|
|
|
| 53 |
kernel_contract_axes = tuple(range(len(norm_axis)))
|
| 54 |
|
| 55 |
output = torch.tensordot(
|
| 56 |
+
inputs.to(self.weight.dtype),
|
| 57 |
+
self.weight,
|
| 58 |
dims=(norm_axis, kernel_contract_axes),
|
| 59 |
).to(inputs.dtype)
|
| 60 |
return output
|
| 61 |
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
class MlpBlock(nn.Module):
|
| 64 |
"""MLP block using DenseGeneral."""
|
| 65 |
|
| 66 |
def __init__(
|
| 67 |
+
self, embed_dim: int, intermediate_dim: int, compute_dtype: torch.dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
):
|
| 69 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
self.dtype = compute_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
self.wi_fused = DenseGeneral(
|
| 73 |
in_shapes=(embed_dim,),
|
| 74 |
+
out_features=(2, intermediate_dim),
|
|
|
|
|
|
|
|
|
|
| 75 |
axis=(-1,),
|
| 76 |
+
weight_dtype=compute_dtype,
|
|
|
|
| 77 |
)
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
self.wo = DenseGeneral(
|
| 80 |
in_shapes=(intermediate_dim,),
|
| 81 |
out_features=(embed_dim,),
|
| 82 |
axis=(-1,),
|
| 83 |
+
weight_dtype=compute_dtype,
|
|
|
|
| 84 |
)
|
| 85 |
|
| 86 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 87 |
"""Forward pass."""
|
|
|
|
|
|
|
|
|
|
| 88 |
fused_x = self.wi_fused(x)
|
| 89 |
|
| 90 |
+
gate = fused_x[..., 0, :]
|
| 91 |
+
up = fused_x[..., 1, :]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
+
hidden = torch.mul(F.silu(gate), up).to(self.dtype)
|
|
|
|
| 94 |
|
| 95 |
output = self.wo(hidden)
|
| 96 |
return output
|
|
|
|
| 139 |
return torch.cat((first_part, second_part), dim=-1)
|
| 140 |
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
class Attention(nn.Module):
|
| 143 |
"""Attention using DenseGeneral."""
|
| 144 |
|
|
|
|
| 150 |
num_query_heads: int,
|
| 151 |
num_kv_heads: int,
|
| 152 |
head_dim: int,
|
| 153 |
+
compute_dtype: torch.dtype,
|
| 154 |
is_cross_attn: bool = False,
|
| 155 |
out_embed_dim: int | None = None,
|
| 156 |
):
|
|
|
|
| 159 |
self.num_kv_heads = num_kv_heads
|
| 160 |
self.head_dim = head_dim
|
| 161 |
self.is_cross_attn = is_cross_attn
|
|
|
|
|
|
|
|
|
|
| 162 |
self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
|
| 163 |
self.projected_query_dim = num_query_heads * head_dim
|
| 164 |
if num_query_heads % num_kv_heads != 0:
|
| 165 |
+
raise ValueError(
|
| 166 |
+
f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
|
| 167 |
+
)
|
| 168 |
self.num_gqa_groups = num_query_heads // num_kv_heads
|
| 169 |
|
| 170 |
# --- Projection Layers using DenseGeneral ---
|
|
|
|
| 172 |
in_shapes=(q_embed_dim,),
|
| 173 |
out_features=(num_query_heads, head_dim),
|
| 174 |
axis=(-1,),
|
| 175 |
+
weight_dtype=compute_dtype,
|
|
|
|
| 176 |
)
|
| 177 |
self.k_proj = DenseGeneral(
|
| 178 |
in_shapes=(kv_embed_dim,),
|
| 179 |
out_features=(num_kv_heads, head_dim),
|
| 180 |
axis=(-1,),
|
| 181 |
+
weight_dtype=compute_dtype,
|
|
|
|
| 182 |
)
|
| 183 |
self.v_proj = DenseGeneral(
|
| 184 |
in_shapes=(kv_embed_dim,),
|
| 185 |
out_features=(num_kv_heads, head_dim),
|
| 186 |
axis=(-1,),
|
| 187 |
+
weight_dtype=compute_dtype,
|
|
|
|
| 188 |
)
|
| 189 |
self.o_proj = DenseGeneral(
|
| 190 |
in_shapes=(num_query_heads, head_dim),
|
| 191 |
out_features=(self.output_dim,),
|
| 192 |
axis=(-2, -1),
|
| 193 |
+
weight_dtype=compute_dtype,
|
|
|
|
| 194 |
)
|
| 195 |
|
| 196 |
# --- Rotary Embedding ---
|
|
|
|
| 207 |
Xkv: torch.Tensor, # (B, S, E) S = 1 in AR generation
|
| 208 |
q_positions: torch.Tensor, # (B, T)
|
| 209 |
kv_positions: torch.Tensor | None = None, # (B, S)
|
| 210 |
+
attn_mask: torch.Tensor
|
| 211 |
+
| None = None, # None in Decoder Self Attention, Valid mask in Others
|
| 212 |
cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
|
| 213 |
+
prefill: bool = False,
|
| 214 |
+
is_causal: bool = False,
|
| 215 |
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
|
| 216 |
"""
|
| 217 |
Performs attention calculation with optional KV caching.
|
|
|
|
| 221 |
Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
|
| 222 |
q_positions: Positions for queries (B, T).
|
| 223 |
kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
|
|
|
|
| 224 |
attn_mask: Attention mask.
|
| 225 |
cache: KVCache.
|
| 226 |
prefill: If True, use prefill mode.
|
|
|
|
| 238 |
Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
|
| 239 |
Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
|
| 240 |
|
|
|
|
| 241 |
attn_k: torch.Tensor | None = None
|
| 242 |
attn_v: torch.Tensor | None = None
|
|
|
|
| 243 |
|
|
|
|
| 244 |
if self.is_cross_attn:
|
|
|
|
| 245 |
attn_k, attn_v = cache.k, cache.v
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
else:
|
| 247 |
Xk_BxSxKxH = self.k_proj(Xkv) # (B, S, K, H)
|
| 248 |
Xv_BxSxKxH = self.v_proj(Xkv) # (B, S, K, H)
|
| 249 |
+
Xk_BxSxKxH = self.rotary_emb(
|
| 250 |
+
Xk_BxSxKxH, position=kv_positions
|
| 251 |
+
) # (B, S, K, H)
|
| 252 |
|
| 253 |
Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
|
| 254 |
Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
|
|
|
|
| 256 |
if cache is None:
|
| 257 |
+
attn_k = Xk_BxKxSxH
|
| 258 |
+
attn_v = Xv_BxKxSxH
|
|
|
|
| 259 |
else:
|
|
|
|
| 260 |
if prefill:
|
| 261 |
+
attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
|
| 262 |
+
cache.prefill(attn_k, attn_v)
|
|
|
|
| 263 |
else:
|
| 264 |
+
attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH)
|
|
|
|
| 265 |
|
| 266 |
attn_output = F.scaled_dot_product_attention(
|
| 267 |
Xq_BxNxTxH,
|
| 268 |
attn_k,
|
| 269 |
attn_v,
|
| 270 |
attn_mask=attn_mask,
|
|
|
|
| 271 |
scale=1.0,
|
| 272 |
+
enable_gqa=self.num_gqa_groups > 1,
|
| 273 |
+
is_causal=is_causal,
|
| 274 |
)
|
| 275 |
|
| 276 |
attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
|
| 277 |
output = self.o_proj(attn_output)
|
| 278 |
|
| 279 |
+
return output.to(original_dtype)
|
| 280 |
|
| 281 |
|
| 282 |
class EncoderLayer(nn.Module):
    """Transformer Encoder Layer using DenseGeneral."""

+   def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model

            dtype=torch.float32,
        )
        self.self_attention = Attention(
+           config,
            q_embed_dim=embed_dim,
            kv_embed_dim=embed_dim,
            num_query_heads=enc_config.n_head,
            num_kv_heads=enc_config.n_head,
            head_dim=enc_config.head_dim,
+           compute_dtype=compute_dtype,
            is_cross_attn=False,
            out_embed_dim=embed_dim,
        )

            dtype=torch.float32,
        )
        self.mlp = MlpBlock(
            embed_dim=embed_dim,
            intermediate_dim=enc_config.n_hidden,
+           compute_dtype=compute_dtype,
        )

    def forward(
        self,
        x: torch.Tensor,
+       state: EncoderInferenceState,
    ) -> torch.Tensor:
        residual = x
        x_norm = self.pre_sa_norm(x)
+       sa_out = self.self_attention(
            Xq=x_norm,
            Xkv=x_norm,
+           q_positions=state.positions,
+           kv_positions=state.positions,
+           attn_mask=state.attn_mask,
        )
        x = residual + sa_out

        residual = x
        x_norm = self.post_sa_norm(x)
+       mlp_out = self.mlp(x_norm)
        x = residual + mlp_out

        return x

class Encoder(nn.Module):
    """Transformer Encoder Stack using DenseGeneral."""

+   def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder

        self.embedding = nn.Embedding(
            model_config.src_vocab_size,
            enc_config.n_embd,
            dtype=compute_dtype,
        )
+       self.layers = nn.ModuleList(
+           [EncoderLayer(config, compute_dtype) for _ in range(enc_config.n_layer)]
+       )
        self.norm = RMSNorm(
            enc_config.n_embd,
            eps=model_config.normalization_layer_epsilon,

    def forward(
        self,
        x_ids: torch.Tensor,
+       state: EncoderInferenceState,
    ) -> torch.Tensor:
        x = self.embedding(x_ids)

        for layer in self.layers:
+           x = layer(x, state)
+
        x = self.norm(x)
        return x

class DecoderLayer(nn.Module):
    """Transformer Decoder Layer using DenseGeneral."""

+   def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model

        # Self-Attention (GQA) with Causal Masking
        self.self_attention = Attention(
+           config,
            q_embed_dim=dec_embed_dim,
            kv_embed_dim=dec_embed_dim,
            num_query_heads=dec_config.gqa_query_heads,
            num_kv_heads=dec_config.kv_heads,
            head_dim=dec_config.gqa_head_dim,
+           compute_dtype=compute_dtype,
            is_cross_attn=False,
            out_embed_dim=dec_embed_dim,
        )

            num_query_heads=dec_config.cross_query_heads,
            num_kv_heads=dec_config.cross_query_heads,
            head_dim=dec_config.cross_head_dim,
+           compute_dtype=compute_dtype,
            is_cross_attn=True,
            out_embed_dim=dec_embed_dim,
        )
        # MLP
        self.mlp = MlpBlock(
            embed_dim=dec_embed_dim,
            intermediate_dim=dec_config.n_hidden,
+           compute_dtype=compute_dtype,
        )

    def forward(
        self,
        x: torch.Tensor,
+       state: DecoderInferenceState,
+       self_attn_cache: KVCache | None = None,
+       cross_attn_cache: KVCache | None = None,
        prefill: bool = False,
    ) -> torch.Tensor:
        residual = x
        x_norm = self.pre_sa_norm(x)

+       sa_out = self.self_attention(
            Xq=x_norm,  # (2, 1, D)
            Xkv=x_norm,  # (2, 1, D)
+           q_positions=state.dec_positions,  # (2, 1)
+           kv_positions=state.dec_positions,  # (2, 1)
+           attn_mask=None,
            cache=self_attn_cache,
            prefill=prefill,
+           is_causal=prefill,
        )

        x = residual + sa_out

        residual = x
        x_norm = self.pre_ca_norm(x)
+       ca_out = self.cross_attention(
            Xq=x_norm,
+           Xkv=state.enc_out,
+           q_positions=state.dec_positions,
+           kv_positions=state.enc_positions,
+           attn_mask=state.dec_cross_attn_mask,
            cache=cross_attn_cache,
        )
        x = residual + ca_out

        residual = x
        x_norm = self.pre_mlp_norm(x)
+       mlp_out = self.mlp(x_norm)
        x = residual + mlp_out

+       return x

class Decoder(nn.Module):
    """Transformer Decoder Stack using DenseGeneral."""

+   def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model
        dec_config = config.model.decoder
        data_config = config.data
        self.num_channels = data_config.channels
        self.num_layers = dec_config.n_layer

        self.embeddings = nn.ModuleList(
            [
+               nn.Embedding(
+                   model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype
+               )
                for _ in range(self.num_channels)
            ]
        )
+       self.layers = nn.ModuleList(
+           [
+               DecoderLayer(config=config, compute_dtype=compute_dtype)
+               for _ in range(self.num_layers)
+           ]
+       )
+
        self.norm = RMSNorm(
            dec_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

        self.logits_dense = DenseGeneral(
            in_shapes=(dec_config.n_embd,),
            out_features=(self.num_channels, model_config.tgt_vocab_size),
            axis=(-1,),
+           weight_dtype=compute_dtype,
        )

+   def precompute_cross_attn_cache(
        self,
+       enc_out: torch.Tensor,  # (B, S, E)
+       enc_positions: torch.Tensor,  # (B, S)
    ) -> list[KVCache]:
        """
        Computes the Key and Value tensors for cross-attention for each layer from the encoder output.

        for layer in self.layers:
            cross_attn_module = layer.cross_attention
+           k_proj = cross_attn_module.k_proj(enc_out)
+           v_proj = cross_attn_module.v_proj(enc_out)

+           k_proj = cross_attn_module.rotary_emb(k_proj, position=enc_positions)
            k = k_proj.transpose(1, 2)
            v = v_proj.transpose(1, 2)

+           per_layer_kv_cache.append(KVCache.from_kv(k, v))

        return per_layer_kv_cache

    def decode_step(
        self,
        tgt_ids_Bx1xC: torch.Tensor,  # [B, 1, C]
+       state: DecoderInferenceState,
    ) -> torch.Tensor:
        """
        Performs a single decoding step, managing KV caches layer by layer.

        A tuple containing:
        - logits_Bx1xCV: The final output logits for the current step (B, 1, C*V), cast to float32.
        """

        x = None
        for i in range(self.num_channels):
            channel_embed = self.embeddings[i](channel_tokens)
            x = channel_embed if x is None else x + channel_embed

        for i, layer in enumerate(self.layers):
+           self_cache = state.self_attn_cache[i]
+           cross_cache = state.cross_attn_cache[i]
+           x = layer(
                x,  # (2, 1, D)
+               state,
                self_attn_cache=self_cache,
                cross_attn_cache=cross_cache,
            )

        x = self.norm(x)
        logits_Bx1xCxV = self.logits_dense(x)

+       return logits_Bx1xCxV.to(torch.float32)

    def forward(
+       self, tgt_ids_BxTxC: torch.Tensor, state: DecoderInferenceState
    ) -> torch.Tensor:
        """
        Forward pass for the Decoder stack, managing KV caches.

            encoder_out: Output from the encoder (B, S, E).
            tgt_positions: Positions for target sequence (B, T).
            src_positions: Positions for source sequence (B, S).
            self_attn_mask: Mask for self-attention.
            cross_attn_mask: Mask for cross-attention.
            past_key_values: List containing the self-attention KV cache for each layer

            channel_embed = self.embeddings[i](channel_tokens)
            x = channel_embed if x is None else x + channel_embed

        for i, layer in enumerate(self.layers):
+           self_cache = state.self_attn_cache[i]
+           cross_cache = state.cross_attn_cache[i]
+           x = layer(
                x,
+               state,
+               self_attn_cache=self_cache,
+               cross_attn_cache=cross_cache,
                prefill=True,
            )

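Because cross-attention keys and values depend only on the encoder output, precompute_cross_attn_cache projects them once per layer and the decode loop reuses the same KVCache objects at every step. A rough sketch of the intended call pattern (illustrative only; model, enc_out and enc_pos are assumed to come from an encoder pass such as the one in Dia._prepare_generation):

# enc_out: (2, S, E) encoder output, enc_pos: (2, S) source positions
cross_cache = model.decoder.precompute_cross_attn_cache(enc_out, enc_pos)
# one KVCache per decoder layer; decode_step then only recomputes self-attention K/V
assert len(cross_cache) == model.decoder.num_layers
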
class DiaModel(nn.Module):
    """PyTorch Dia Model using DenseGeneral."""

+   def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
+       self.encoder = Encoder(config, compute_dtype)
+       self.decoder = Decoder(config, compute_dtype)
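
Every module in the stack now takes an explicit compute_dtype at construction, so the whole model can be built directly in half precision instead of being cast afterwards. A minimal sketch (assuming a DiaConfig instance named config loaded elsewhere):

import torch

model = DiaModel(config, compute_dtype=torch.bfloat16)  # encoder and decoder share the dtype
# the RMSNorm layers are still created with dtype=torch.float32, as shown in the diff above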
dia/model.py
CHANGED
@@ -1,26 +1,46 @@
import dac
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download

-from .audio import
from .config import DiaConfig
-from .layers import DiaModel


def _sample_next_token(
    logits_BCxV: torch.Tensor,
    temperature: float,
    top_p: float,
-   use_cfg_filter: bool,
    cfg_filter_top_k: int | None = None,
) -> torch.Tensor:
    if temperature == 0.0:
        return torch.argmax(logits_BCxV, dim=-1)

    logits_BCxV = logits_BCxV / temperature
-   if
    _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
    mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
    mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)

@@ -28,17 +48,21 @@ def _sample_next_token(

    if top_p < 1.0:
        probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
-       sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
        cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)

-       # Calculate indices to remove based on top_p
        sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
-
-
-

        indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
-       indices_to_remove_BCxV.scatter_(
        logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)

    final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)

@@ -48,31 +72,61 @@ def _sample_next_token(

    return sampled_indices_C


class Dia:
-   def __init__(
        """Initializes the Dia model.

        Args:
            config: The configuration object for the model.
-           device: The device to load the model onto.

        Raises:
            RuntimeError: If there is an error loading the DAC model.
        """
        super().__init__()
        self.config = config
-       self.device = device
-
        self.dac_model = None

    @classmethod
-   def from_local(
        """Loads the Dia model from local configuration and checkpoint files.

        Args:
            config_path: Path to the configuration JSON file.
            checkpoint_path: Path to the model checkpoint (.pth) file.
-           device: The device to load the model onto.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

@@ -85,23 +139,29 @@ class Dia:

        if config is None:
            raise FileNotFoundError(f"Config file not found at {config_path}")

-       dia = cls(config, device)

        try:
-
        except FileNotFoundError:
            raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
        except Exception as e:
-           raise RuntimeError(

-       dia.model.to(device)
        dia.model.eval()
        dia._load_dac_model()
        return dia

    @classmethod
    def from_pretrained(
-       cls,
    ) -> "Dia":
        """Loads the Dia model from a Hugging Face Hub repository.

@@ -110,7 +170,7 @@ class Dia:

        Args:
            model_name: The Hugging Face Hub repository ID (e.g., "NariLabs/Dia-1.6B").
-           device: The device to load the model onto.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

@@ -121,7 +181,7 @@ class Dia:

        """
        config_path = hf_hub_download(repo_id=model_name, filename="config.json")
        checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
-       return cls.from_local(config_path, checkpoint_path, device)

    def _load_dac_model(self):
        try:

@@ -131,44 +191,7 @@ class Dia:

            raise RuntimeError("Failed to load DAC model") from e
        self.dac_model = dac_model

-   def
-       self,
-       q_padding_mask_1d: torch.Tensor,
-       k_padding_mask_1d: torch.Tensor,
-       is_causal: bool = False,
-   ) -> torch.Tensor:
-       """
-       Creates the attention mask (self or cross) mimicking JAX segment ID logic.
-       """
-       B1, Tq = q_padding_mask_1d.shape
-       B2, Tk = k_padding_mask_1d.shape
-       assert B1 == B2, "Query and key batch dimensions must match"
-
-       p_mask_q = q_padding_mask_1d.unsqueeze(2)  # Shape [B, Tq, 1]
-       p_mask_k = k_padding_mask_1d.unsqueeze(1)  # Shape [B, 1, Tk]
-
-       # Condition A: Non-padding query attends to non-padding key
-       non_pad_attends_non_pad = p_mask_q & p_mask_k  # Shape [B, Tq, Tk]
-
-       # Condition B: Padding query attends to padding key
-       pad_attends_pad = (~p_mask_q) & (~p_mask_k)  # Shape [B, Tq, Tk]
-
-       # Combine: True if padding status is compatible (both non-pad OR both pad)
-       # This implementation follows Jax TPU splash attention kernel
-       mask = non_pad_attends_non_pad | pad_attends_pad  # Shape [B, Tq, Tk]
-
-       if is_causal:
-           # Ensure causality for self-attention (Tq == Tk)
-           assert Tq == Tk, "Causal mask requires query and key sequence lengths to be equal"
-           # Standard lower-triangular causal mask (True means allow)
-           causal_mask_2d = torch.tril(torch.ones((Tq, Tk), dtype=torch.bool, device=self.device))  # Shape [Tq, Tk]
-           causal_mask = mask & causal_mask_2d  # Shape [B, Tq, Tk]
-           return causal_mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk] for broadcasting across heads
-       else:
-           # For cross-attention or non-causal self-attention
-           return mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk] for broadcasting across heads
-
-   def _prepare_text_input(self, text: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encodes text prompt, pads, and creates attention mask and positions."""
        text_pad_value = self.config.data.text_pad_value
        max_len = self.config.data.text_length

@@ -190,14 +213,168 @@ class Dia:

            constant_values=text_pad_value,
        ).astype(np.uint8)

-       src_tokens =
-
-
-
-

    @torch.inference_mode()
    def generate(

@@ -207,225 +384,105 @@ class Dia:

        cfg_scale: float = 3.0,
        temperature: float = 1.3,
        top_p: float = 0.95,
-
-
-
        audio_prompt_path: str | None = None,
    ) -> np.ndarray:
-       """
-       Generates audio from a text prompt (and optional audio prompt) using the Nari model.
-
-       Returns:
-           A tensor of generated audio codes (shape: [max_tokens, num_channels]).
-       """
-       num_channels = self.config.data.channels
-       audio_bos_value = self.config.data.audio_bos_value
        audio_eos_value = self.config.data.audio_eos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
-       delay_tensor = torch.tensor(delay_pattern, dtype=torch.long, device=self.device)
        max_delay_pattern = max(delay_pattern)
        self.model.eval()

-
-
-
-
-
-       ) = self._prepare_text_input(text)
-
-       unc_src_BxS = torch.zeros_like(cond_src_BxS)
-       src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0)
-       src_positions_BxS = cond_src_positions_BxS.expand(2, -1)
-       src_padding_mask_BxS = cond_src_padding_mask_BxS.expand(2, -1)
-       enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand(2, -1, -1, -1)
-
-       # 2. Encoder Pass
-       # with torch.autocast(device_type="cuda", dtype=forward_dtype):
-       encoder_out = self.model.encoder(
-           x_ids=src_BxS,
-           src_positions=src_positions_BxS,
-           deterministic=True,
-           attn_mask=enc_self_attn_mask_Bx1xSxS,
-       )  # Shape: (B, S, E)
-
-       # 3. Prepare Decoder Inputs
-       # 3-1. Allocate KV Cache (Static)
-       decoder_cross_attention_cache: list[KVCache] = self.model.decoder.precompute_cross_attention_kv(
-           max_tokens, encoder_out, src_positions_BxS
-       )
-
-       decoder_self_attention_cache: list[KVCache] = []
-       for _ in range(self.model.decoder.num_layers):
-           decoder_self_attention_cache.append(
-               KVCache(
-                   self.config.model.decoder.gqa_query_heads,
-                   max_tokens,
-                   self.config.model.decoder.gqa_head_dim,
-                   self.device,
-               )
-           )
-
-       # 3-2. Initialize Decoder Inputs
-       generated_BxTxC = torch.full(
-           (2, 1, num_channels),
-           fill_value=audio_bos_value,
-           dtype=torch.long,
-           device=self.device,
-       )
-
-       current_step = 0
-       prompt_len_inc_bos = 1  # Start with BOS length
-
-       # 3-3. Load Audio Prompt (if provided)
-       if audio_prompt_path is not None:
-           audio_prompt, sr = torchaudio.load(audio_prompt_path, channels_first=True)  # C, T
-           if sr != 44100:  # Resample to 44.1kHz
-               audio_prompt = torchaudio.functional.resample(audio_prompt, sr, 44100)
-           audio_prompt = audio_prompt.to(self.device).unsqueeze(0)  # 1, C, T
-           audio_prompt = audio_to_codebook(self.dac_model, audio_prompt, data_config=self.config.data)
-           generated_BxTxC = torch.cat([generated_BxTxC, audio_prompt.expand(2, -1, -1)], dim=1)
-
-           prefill_len = generated_BxTxC.shape[1]
-           prompt_len_inc_bos = prefill_len
-           prefill_tgt_pos = torch.arange(prefill_len, device=self.device).unsqueeze(0).expand(2, -1)
-           prefill_tgt_padding_mask = (generated_BxTxC != audio_pad_value).any(dim=2)
-
-           prefill_self_attn_mask = self._create_attn_mask(
-               prefill_tgt_padding_mask,
-               prefill_tgt_padding_mask,
-               is_causal=True,
-           )
-           prefill_cross_attn_mask = self._create_attn_mask(
-               prefill_tgt_padding_mask,
-               src_padding_mask_BxS,
-               is_causal=False,
-           )

-
-
-           encoder_out=encoder_out,
-           tgt_positions=prefill_tgt_pos,
-           src_positions=src_positions_BxS,
-           deterministic=True,
-           self_attn_mask=prefill_self_attn_mask,
-           cross_attn_mask=prefill_cross_attn_mask,
-           self_attention_cache=decoder_self_attention_cache,
-           cross_attention_cache=decoder_cross_attention_cache,
-       )

-
-
-
        eos_countdown = -1
-       extra_steps_after_eos = 30
-       # Make generated_BxTxC a fixed size tensor
-       # Length is either 1 + max tokens or 1 + prompt len + max tokens
-       generated_BxTxC = torch.cat(
-           [
-               generated_BxTxC,
-               torch.full(
-                   (2, max_tokens, num_channels),
-                   fill_value=-1,
-                   dtype=torch.long,
-                   device=self.device,
-               ),
-           ],
-           dim=1,
-       )

-       decode_step = self.model.decoder.decode_step
        if use_torch_compile:
-
-
-
-       )

-
-       (
-
-
-
-
-           is_causal=False,
-       )  # [B, 1, 1, S]
-
-       for step in range(current_step, current_step + max_tokens):
-           tgt_ids_Bx1xC = generated_BxTxC[:, step, :].unsqueeze(1)
-           tgt_pos_Bx1 = torch.full(
-               (2, 1),
-               fill_value=step,
-               dtype=torch.long,
-               device=self.device,
-           )

-
-
-
-
-               self_attn_mask=None,
-               cross_attn_mask=decoder_cross_attn_mask,
-               self_attention_cache=decoder_self_attention_cache,
-               cross_attention_cache=decoder_cross_attention_cache,
            )
-
-
-
-
-
-
-
-           cond_logits_CxV = logits_last_BxCxV[1, :, :]
-
-           cfg_logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
-
-           logits_CxV = cfg_logits_CxV.reshape((-1, V))  # C, V
-           logits_CxV[:, 1025:] = -torch.inf
-
-           # Sample next token
-           pred_C = _sample_next_token(
-               logits_CxV.float(),
-               temperature=temperature,
-               top_p=top_p,
-               use_cfg_filter=use_cfg_filter,
-               cfg_filter_top_k=cfg_filter_top_k,
            )

-
-
-
-
-
-               audio_bos_value,
-           )
-
-           generated_BxTxC[:, step + 1, :] = pred_C.unsqueeze(0).expand(2, -1)
-
-           if not eos_detected_channel_0 and pred_C[0] == audio_eos_value:
-               eos_detected_channel_0 = True
-               eos_countdown = extra_steps_after_eos

            if eos_countdown > 0:
                step_after_eos = max_delay_pattern - eos_countdown
                for i, d in enumerate(delay_pattern):
                    if step_after_eos == d:
-
                    elif step_after_eos > d:
-
                eos_countdown -= 1
-               if eos_countdown == 0:
-                   break

-
-
-
-
-
-
-
+import time
+from enum import Enum
+
import dac
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download

+from .audio import (
+    apply_audio_delay,
+    build_delay_indices,
+    build_revert_indices,
+    decode,
+    revert_audio_delay,
+)
from .config import DiaConfig
+from .layers import DiaModel
+from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState
+
+
+DEFAULT_SAMPLE_RATE = 44100
+
+
+def _get_default_device():
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        return torch.device("mps")
+    return torch.device("cpu")


def _sample_next_token(
    logits_BCxV: torch.Tensor,
    temperature: float,
    top_p: float,
    cfg_filter_top_k: int | None = None,
) -> torch.Tensor:
    if temperature == 0.0:
        return torch.argmax(logits_BCxV, dim=-1)

    logits_BCxV = logits_BCxV / temperature
+   if cfg_filter_top_k is not None:
        _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
        mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
        mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)

    if top_p < 1.0:
        probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
+       sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
+           probs_BCxV, dim=-1, descending=True
+       )
        cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)

        sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
+       sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[
+           ..., :-1
+       ].clone()
+       sorted_indices_to_remove_BCxV[..., 0] = 0

        indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
+       indices_to_remove_BCxV.scatter_(
+           dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
+       )
        logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)

    final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)

    return sampled_indices_C


+class ComputeDtype(str, Enum):
+    FLOAT32 = "float32"
+    FLOAT16 = "float16"
+    BFLOAT16 = "bfloat16"
+
+    def to_dtype(self) -> torch.dtype:
+        if self == ComputeDtype.FLOAT32:
+            return torch.float32
+        elif self == ComputeDtype.FLOAT16:
+            return torch.float16
+        elif self == ComputeDtype.BFLOAT16:
+            return torch.bfloat16
+        else:
+            raise ValueError(f"Unsupported compute dtype: {self}")
+
+
class Dia:
+   def __init__(
+       self,
+       config: DiaConfig,
+       compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
+       device: torch.device | None = None,
+   ):
        """Initializes the Dia model.

        Args:
            config: The configuration object for the model.
+           device: The device to load the model onto. If None, will automatically select the best available device.

        Raises:
            RuntimeError: If there is an error loading the DAC model.
        """
        super().__init__()
        self.config = config
+       self.device = device if device is not None else _get_default_device()
+       if isinstance(compute_dtype, str):
+           compute_dtype = ComputeDtype(compute_dtype)
+       self.compute_dtype = compute_dtype.to_dtype()
+       self.model = DiaModel(config, self.compute_dtype)
        self.dac_model = None

    @classmethod
+   def from_local(
+       cls,
+       config_path: str,
+       checkpoint_path: str,
+       compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
+       device: torch.device | None = None,
+   ) -> "Dia":
        """Loads the Dia model from local configuration and checkpoint files.

        Args:
            config_path: Path to the configuration JSON file.
            checkpoint_path: Path to the model checkpoint (.pth) file.
+           device: The device to load the model onto. If None, will automatically select the best available device.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        if config is None:
            raise FileNotFoundError(f"Config file not found at {config_path}")

+       dia = cls(config, compute_dtype, device)

        try:
+           state_dict = torch.load(checkpoint_path, map_location=dia.device)
+           dia.model.load_state_dict(state_dict)
        except FileNotFoundError:
            raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
        except Exception as e:
+           raise RuntimeError(
+               f"Error loading checkpoint from {checkpoint_path}"
+           ) from e

+       dia.model.to(dia.device)
        dia.model.eval()
        dia._load_dac_model()
        return dia

    @classmethod
    def from_pretrained(
+       cls,
+       model_name: str = "nari-labs/Dia-1.6B",
+       compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
+       device: torch.device | None = None,
    ) -> "Dia":
        """Loads the Dia model from a Hugging Face Hub repository.

        Args:
            model_name: The Hugging Face Hub repository ID (e.g., "NariLabs/Dia-1.6B").
+           device: The device to load the model onto. If None, will automatically select the best available device.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        """
        config_path = hf_hub_download(repo_id=model_name, filename="config.json")
        checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
+       return cls.from_local(config_path, checkpoint_path, compute_dtype, device)

    def _load_dac_model(self):
        try:

            raise RuntimeError("Failed to load DAC model") from e
        self.dac_model = dac_model

+   def _prepare_text_input(self, text: str) -> torch.Tensor:
        """Encodes text prompt, pads, and creates attention mask and positions."""
        text_pad_value = self.config.data.text_pad_value
        max_len = self.config.data.text_length

            constant_values=text_pad_value,
        ).astype(np.uint8)

+       src_tokens = (
+           torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)
+       )  # [1, S]
+       return src_tokens

+   def _prepare_audio_prompt(
+       self, audio_prompt: torch.Tensor | None
+   ) -> tuple[torch.Tensor, int]:
+       num_channels = self.config.data.channels
+       audio_bos_value = self.config.data.audio_bos_value
+       audio_pad_value = self.config.data.audio_pad_value
+       delay_pattern = self.config.data.delay_pattern
+       max_delay_pattern = max(delay_pattern)

+       prefill = torch.full(
+           (1, num_channels),
+           fill_value=audio_bos_value,
+           dtype=torch.int,
+           device=self.device,
+       )

+       prefill_step = 1
+
+       if audio_prompt is not None:
+           prefill_step += audio_prompt.shape[0]
+           prefill = torch.cat([prefill, audio_prompt], dim=0)
+
+       delay_pad_tensor = torch.full(
+           (max_delay_pattern, num_channels),
+           fill_value=-1,
+           dtype=torch.int,
+           device=self.device,
+       )
+       prefill = torch.cat([prefill, delay_pad_tensor], dim=0)
+
+       delay_precomp = build_delay_indices(
+           B=1,
+           T=prefill.shape[0],
+           C=num_channels,
+           delay_pattern=delay_pattern,
+       )
+
+       prefill = apply_audio_delay(
+           audio_BxTxC=prefill.unsqueeze(0),
+           pad_value=audio_pad_value,
+           bos_value=audio_bos_value,
+           precomp=delay_precomp,
+       ).squeeze(0)
+
+       return prefill, prefill_step
+
+   def _prepare_generation(
+       self, text: str, audio_prompt: str | torch.Tensor | None, verbose: bool
+   ):
+       enc_input_cond = self._prepare_text_input(text)
+       enc_input_uncond = torch.zeros_like(enc_input_cond)
+       enc_input = torch.cat([enc_input_uncond, enc_input_cond], dim=0)
+
+       if isinstance(audio_prompt, str):
+           audio_prompt = self.load_audio(audio_prompt)
+       prefill, prefill_step = self._prepare_audio_prompt(audio_prompt)
+
+       if verbose:
+           print("generate: data loaded")
+
+       enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
+       encoder_out = self.model.encoder(enc_input, enc_state)
+
+       dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(
+           encoder_out, enc_state.positions
+       )
+       dec_state = DecoderInferenceState.new(
+           self.config,
+           enc_state,
+           encoder_out,
+           dec_cross_attn_cache,
+           self.compute_dtype,
+       )
+       dec_output = DecoderOutput.new(self.config, self.device)
+       dec_output.prefill(prefill, prefill_step)
+
+       dec_step = prefill_step - 1
+       if dec_step > 0:
+           dec_state.prepare_step(0, dec_step)
+           tokens_BxTxC = (
+               dec_output.get_tokens_at(0, dec_step).unsqueeze(0).expand(2, -1, -1)
+           )
+           self.model.decoder.forward(tokens_BxTxC, dec_state)
+
+       return dec_state, dec_output
+
+   def _decoder_step(
+       self,
+       tokens_Bx1xC: torch.Tensor,
+       dec_state: DecoderInferenceState,
+       cfg_scale: float,
+       temperature: float,
+       top_p: float,
+       cfg_filter_top_k: int,
+   ) -> torch.Tensor:
+       audio_eos_value = self.config.data.audio_eos_value
+       logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state)
+
+       logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]
+       uncond_logits_CxV = logits_last_BxCxV[0, :, :]
+       cond_logits_CxV = logits_last_BxCxV[1, :, :]
+
+       logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
+       logits_CxV[:, audio_eos_value + 1 :] = -torch.inf
+       logits_CxV[1:, audio_eos_value:] = -torch.inf
+
+       pred_C = _sample_next_token(
+           logits_CxV.float(),
+           temperature=temperature,
+           top_p=top_p,
+           cfg_filter_top_k=cfg_filter_top_k,
+       )
+       return pred_C
+
+   def _generate_output(self, generated_codes: torch.Tensor) -> np.ndarray:
+       num_channels = self.config.data.channels
+       seq_length = generated_codes.shape[0]
+       delay_pattern = self.config.data.delay_pattern
+       audio_pad_value = self.config.data.audio_pad_value
+       max_delay_pattern = max(delay_pattern)
+
+       revert_precomp = build_revert_indices(
+           B=1,
+           T=seq_length,
+           C=num_channels,
+           delay_pattern=delay_pattern,
+       )
+
+       codebook = revert_audio_delay(
+           audio_BxTxC=generated_codes.unsqueeze(0),
+           pad_value=audio_pad_value,
+           precomp=revert_precomp,
+           T=seq_length,
+       )[:, :-max_delay_pattern, :]
+
+       min_valid_index = 0
+       max_valid_index = 1023
+       invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
+       codebook[invalid_mask] = 0
+
+       audio = decode(self.dac_model, codebook.transpose(1, 2))
+
+       return audio.squeeze().cpu().numpy()
+
+   def load_audio(self, audio_path: str) -> torch.Tensor:
+       audio, sr = torchaudio.load(audio_path, channels_first=True)  # C, T
+       if sr != DEFAULT_SAMPLE_RATE:
+           audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
+       audio = audio.to(self.device).unsqueeze(0)  # 1, C, T
+       audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
+       _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data)  # 1, C, T
+       return encoded_frame.squeeze(0).transpose(0, 1)
+
+   def save_audio(self, path: str, audio: np.ndarray):
+       import soundfile as sf
+
+       sf.write(path, audio, DEFAULT_SAMPLE_RATE)

    @torch.inference_mode()
    def generate(

        cfg_scale: float = 3.0,
        temperature: float = 1.3,
        top_p: float = 0.95,
+       use_torch_compile: bool = False,
+       cfg_filter_top_k: int = 35,
+       audio_prompt: str | torch.Tensor | None = None,
        audio_prompt_path: str | None = None,
+       use_cfg_filter: bool | None = None,
+       verbose: bool = False,
    ) -> np.ndarray:

        audio_eos_value = self.config.data.audio_eos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
        max_delay_pattern = max(delay_pattern)
        self.model.eval()

+       if audio_prompt_path:
+           print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
+           audio_prompt = audio_prompt_path
+       if use_cfg_filter is not None:
+           print("Warning: use_cfg_filter is deprecated.")

+       if verbose:
+           total_start_time = time.time()

+       dec_state, dec_output = self._prepare_generation(text, audio_prompt, verbose)
+       dec_step = dec_output.prefill_step - 1

+       bos_countdown = max_delay_pattern
+       eos_detected = False
        eos_countdown = -1

        if use_torch_compile:
+           step_fn = torch.compile(self._decoder_step, mode="default")
+       else:
+           step_fn = self._decoder_step

+       if verbose:
+           print("generate: starting generation loop")
+           if use_torch_compile:
+               print(
+                   "generate: by using use_torch_compile=True, the first step would take long"
+               )
+           start_time = time.time()

+       while dec_step < max_tokens:
+           dec_state.prepare_step(dec_step)
+           tokens_Bx1xC = (
+               dec_output.get_tokens_at(dec_step).unsqueeze(0).expand(2, -1, -1)
            )
+           pred_C = step_fn(
+               tokens_Bx1xC,
+               dec_state,
+               cfg_scale,
+               temperature,
+               top_p,
+               cfg_filter_top_k,
            )

+           if (
+               not eos_detected and pred_C[0] == audio_eos_value
+           ) or dec_step == max_tokens - max_delay_pattern - 1:
+               eos_detected = True
+               eos_countdown = max_delay_pattern

            if eos_countdown > 0:
                step_after_eos = max_delay_pattern - eos_countdown
                for i, d in enumerate(delay_pattern):
                    if step_after_eos == d:
+                       pred_C[i] = audio_eos_value
                    elif step_after_eos > d:
+                       pred_C[i] = audio_pad_value
                eos_countdown -= 1

+           bos_countdown = max(0, bos_countdown - 1)
+           dec_output.update_one(pred_C, dec_step + 1, bos_countdown > 0)

+           if eos_countdown == 0:
+               break

+           dec_step += 1
+           if verbose and dec_step % 86 == 0:
+               duration = time.time() - start_time
+               print(
+                   f"generate step {dec_step}: speed={86 / duration:.3f} tokens/s, realtime factor={1 / duration:.3f}x"
+               )
+               start_time = time.time()

+       if dec_output.prefill_step >= dec_step + 1:
+           print("Warning: Nothing generated")
+           return None
+
+       generated_codes = dec_output.generated_tokens[
+           dec_output.prefill_step : dec_step + 1, :
+       ]
+
+       if verbose:
+           total_step = dec_step + 1 - dec_output.prefill_step
+           total_duration = time.time() - total_start_time
+           print(
+               f"generate: total step={total_step}, total duration={total_duration:.3f}s"
+           )
+
+       return self._generate_output(generated_codes)
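
Taken together, the reworked Dia class exposes the compute dtype and device at construction time and deprecates audio_prompt_path/use_cfg_filter in favor of audio_prompt and a plain cfg_filter_top_k. An illustrative usage sketch (not part of the commit; the prompt text and output path are placeholders):

from dia.model import Dia

model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
audio = model.generate(
    "[S1] Example prompt for the new inference path.",
    cfg_scale=3.0,
    temperature=1.3,
    top_p=0.95,
    cfg_filter_top_k=35,
    use_torch_compile=False,  # True compiles the per-step decoder function; the first step is slow
    verbose=True,
)
model.save_audio("output.wav", audio)
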
dia/state.py
ADDED
@@ -0,0 +1,234 @@
from dataclasses import dataclass

import torch

from .config import DiaConfig


def create_attn_mask(
    q_padding_mask_1d: torch.Tensor,
    k_padding_mask_1d: torch.Tensor,
    device: torch.device,
    is_causal: bool = False,
) -> torch.Tensor:
    """
    Creates the attention mask (self or cross) mimicking JAX segment ID logic.
    """
    B1, Tq = q_padding_mask_1d.shape
    B2, Tk = k_padding_mask_1d.shape
    assert B1 == B2, "Query and key batch dimensions must match"

    p_mask_q = q_padding_mask_1d.unsqueeze(2)  # Shape [B, Tq, 1]
    p_mask_k = k_padding_mask_1d.unsqueeze(1)  # Shape [B, 1, Tk]

    # Condition A: Non-padding query attends to non-padding key
    non_pad_attends_non_pad = p_mask_q & p_mask_k  # Shape [B, Tq, Tk]

    # Condition B: Padding query attends to padding key
    pad_attends_pad = (~p_mask_q) & (~p_mask_k)  # Shape [B, Tq, Tk]

    # Combine: True if padding status is compatible (both non-pad OR both pad)
    mask = non_pad_attends_non_pad | pad_attends_pad  # Shape [B, Tq, Tk]

    if is_causal:
        assert Tq == Tk, (
            "Causal mask requires query and key sequence lengths to be equal"
        )
        causal_mask_2d = torch.tril(
            torch.ones((Tq, Tk), dtype=torch.bool, device=device)
        )  # Shape [Tq, Tk]
        causal_mask = mask & causal_mask_2d  # Shape [B, Tq, Tk]
        return causal_mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk]
    else:
        return mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk]


@dataclass
class EncoderInferenceState:
    """Parameters specifically for encoder inference."""

    max_seq_len: int
    device: torch.device
    positions: torch.Tensor
    padding_mask: torch.Tensor
    attn_mask: torch.Tensor

    @classmethod
    def new(cls, config: DiaConfig, cond_src: torch.Tensor) -> "EncoderInferenceState":
        """Creates EncoderInferenceState from DiaConfig and a device."""
        device = cond_src.device

        positions = (
            torch.arange(config.data.text_length, device=device)
            .to(torch.long)
            .unsqueeze(0)
            .expand(2, -1)
        )
        padding_mask = (cond_src != config.data.text_pad_value).to(device).expand(2, -1)
        attn_mask = create_attn_mask(
            padding_mask, padding_mask, device, is_causal=False
        )

        return cls(
            max_seq_len=config.data.text_length,
            device=device,
            positions=positions,
            padding_mask=padding_mask,
            attn_mask=attn_mask,
        )


class KVCache:
    def __init__(
        self,
        num_heads: int,
        max_len: int,
        head_dim: int,
        dtype: torch.dtype,
        device: torch.device,
        k: torch.Tensor | None = None,
        v: torch.Tensor | None = None,
    ):
        self.k = (
            torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
            if k is None
            else k
        )
        self.v = (
            torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
            if v is None
            else v
        )
        self.current_idx = torch.tensor(0)

    @classmethod
    def from_kv(cls, k: torch.Tensor, v: torch.Tensor) -> "KVCache":
        return cls(
            num_heads=k.shape[1],
            max_len=k.shape[2],
            head_dim=k.shape[3],
            dtype=k.dtype,
            device=k.device,
            k=k,
            v=v,
        )

    def update(
        self, k: torch.Tensor, v: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
        self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
        self.current_idx += 1
        return self.k[:, :, : self.current_idx, :], self.v[:, :, : self.current_idx, :]

    def prefill(
        self, k: torch.Tensor, v: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        prefill_len = k.shape[2]
        self.k[:, :, :prefill_len, :] = k
        self.v[:, :, :prefill_len, :] = v
        self.current_idx = prefill_len - 1


@dataclass
class DecoderInferenceState:
    """Parameters specifically for decoder inference."""

    device: torch.device
    dtype: torch.dtype
    enc_out: torch.Tensor
    enc_positions: torch.Tensor
    dec_positions: torch.Tensor
    dec_cross_attn_mask: torch.Tensor
    self_attn_cache: list[KVCache]
    cross_attn_cache: list[KVCache]

    @classmethod
    def new(
        cls,
        config: DiaConfig,
        enc_state: EncoderInferenceState,
        enc_out: torch.Tensor,
        dec_cross_attn_cache: list[KVCache],
        compute_dtype: torch.dtype,
    ) -> "DecoderInferenceState":
        """Creates DecoderInferenceState from DiaConfig and a device."""
        device = enc_out.device
        max_audio_len = config.data.audio_length

        dec_positions = torch.full(
            (2, 1), fill_value=0, dtype=torch.long, device=device
        )
        tgt_padding_mask = torch.ones((2, 1), dtype=torch.bool, device=device)
        dec_cross_attn_mask = create_attn_mask(
            tgt_padding_mask, enc_state.padding_mask, device, is_causal=False
        )

        self_attn_cache = [
            KVCache(
                config.model.decoder.kv_heads,
                max_audio_len,
                config.model.decoder.gqa_head_dim,
                compute_dtype,
                device,
            )
            for _ in range(config.model.decoder.n_layer)
        ]

        return cls(
            device=device,
            dtype=compute_dtype,
            enc_out=enc_out,
            enc_positions=enc_state.positions,
            dec_positions=dec_positions,
            dec_cross_attn_mask=dec_cross_attn_mask,
            self_attn_cache=self_attn_cache,
            cross_attn_cache=dec_cross_attn_cache,
        )

    def prepare_step(self, step_from: int, step_to: int | None = None) -> None:
        if step_to is None:
            step_to = step_from + 1
        self.dec_positions = (
            torch.arange(step_from, step_to, device=self.device)
            .unsqueeze(0)
            .expand(2, -1)
        )


@dataclass
class DecoderOutput:
    generated_tokens: torch.Tensor
    prefill_step: int

    @classmethod
    def new(cls, config: DiaConfig, device: torch.device) -> "DecoderOutput":
        max_audio_len = config.data.audio_length
        return cls(
            generated_tokens=torch.full(
                (max_audio_len, config.data.channels),
                fill_value=-1,
                dtype=torch.int,
                device=device,
            ),
            prefill_step=0,
        )

    def get_tokens_at(self, step_from: int, step_to: int | None = None) -> torch.Tensor:
        if step_to is None:
            step_to = step_from + 1
        return self.generated_tokens[step_from:step_to, :]

    def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
        if apply_mask:
            mask = self.generated_tokens[step : step + 1, :] == -1
            self.generated_tokens[step : step + 1, :] = torch.where(
                mask, dec_out, self.generated_tokens[step : step + 1, :]
            )
        else:
            self.generated_tokens[step : step + 1, :] = dec_out

    def prefill(self, dec_out: torch.Tensor, prefill_step: int):
        length = dec_out.shape[0]
        self.generated_tokens[0:length, :] = dec_out
        self.prefill_step = prefill_step
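
The KVCache above backs both attention paths: prefill copies a whole prompt's keys/values in one shot, while update writes exactly one new position per decode step and returns the valid prefix. A small illustrative exercise of the single-step path (shapes chosen arbitrarily; the leading dimension of 2 matches the unconditional/conditional CFG pair used throughout):

import torch

cache = KVCache(num_heads=4, max_len=8, head_dim=16,
                dtype=torch.float32, device=torch.device("cpu"))

k_step = torch.randn(2, 4, 1, 16)   # (2, num_heads, 1, head_dim)
v_step = torch.randn(2, 4, 1, 16)
k_all, v_all = cache.update(k_step, v_step)
print(k_all.shape)  # torch.Size([2, 4, 1, 16]); grows by one position per update call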