# Hugging Face Spaces badge text ("Spaces: Running on Zero") — non-code page
# residue from the hosted README rendering; kept as a comment, not code.
| """Configuration management module for the Dia model. | |
| This module provides comprehensive configuration management for the Dia model, | |
| utilizing Pydantic for validation. It defines configurations for data processing, | |
| model architecture (encoder and decoder), and training settings. | |
| Key components: | |
| - DataConfig: Parameters for data loading and preprocessing. | |
| - EncoderConfig: Architecture details for the encoder module. | |
| - DecoderConfig: Architecture details for the decoder module. | |
| - ModelConfig: Combined model architecture settings. | |
| - TrainingConfig: Training hyperparameters and settings. | |
| - DiaConfig: Master configuration combining all components. | |
| """ | |
import os
from typing import Annotated

from pydantic import BaseModel, BeforeValidator, Field
class DataConfig(BaseModel, frozen=True):
    """Configuration for data loading and preprocessing.

    The two sequence-length fields are normalized before validation: the
    ``BeforeValidator`` rounds any positive value up to the next multiple
    of 128, which the ``Field`` constraint then re-checks.

    Attributes:
        text_length: Maximum length of text sequences (must be multiple of 128).
        audio_length: Maximum length of audio sequences (must be multiple of 128).
        channels: Number of audio channels.
        text_pad_value: Value used for padding text sequences.
        audio_eos_value: Value representing the end of audio sequences.
        audio_bos_value: Value representing the beginning of audio sequences.
        audio_pad_value: Value used for padding audio sequences.
        delay_pattern: List of delay values for each audio channel.
    """

    text_length: Annotated[int, BeforeValidator(lambda v: (v + 127) // 128 * 128)] = (
        Field(gt=0, multiple_of=128)
    )
    audio_length: Annotated[int, BeforeValidator(lambda v: (v + 127) // 128 * 128)] = (
        Field(gt=0, multiple_of=128)
    )
    channels: int = Field(default=9, gt=0, multiple_of=1)
    text_pad_value: int = Field(default=0)
    audio_eos_value: int = Field(default=1024)
    audio_pad_value: int = Field(default=1025)
    audio_bos_value: int = Field(default=1026)
    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(
        default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
    )

    def __hash__(self) -> int:
        """Generate a hash based on all fields of the config."""
        # delay_pattern is a list, so it is converted to a tuple to make the
        # composite key hashable.
        key = (
            self.text_length,
            self.audio_length,
            self.channels,
            self.text_pad_value,
            self.audio_pad_value,
            self.audio_bos_value,
            self.audio_eos_value,
            tuple(self.delay_pattern),
        )
        return hash(key)
class EncoderConfig(BaseModel, frozen=True):
    """Configuration for the encoder component of the Dia model.

    All dimensions are validated to be strictly positive.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        n_head: Number of attention heads.
        head_dim: Dimension per attention head.
    """

    n_layer: int = Field(gt=0)
    n_embd: int = Field(gt=0)
    n_hidden: int = Field(gt=0)
    n_head: int = Field(gt=0)
    head_dim: int = Field(gt=0)
class DecoderConfig(BaseModel, frozen=True):
    """Configuration for the decoder component of the Dia model.

    All dimensions are validated to be strictly positive. Self-attention uses
    grouped-query attention (separate query and key/value head counts);
    cross-attention has its own head configuration.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        gqa_query_heads: Number of query heads for grouped-query self-attention.
        kv_heads: Number of key/value heads for grouped-query self-attention.
        gqa_head_dim: Dimension per query head for grouped-query self-attention.
        cross_query_heads: Number of query heads for cross-attention.
        cross_head_dim: Dimension per cross-attention head.
    """

    n_layer: int = Field(gt=0)
    n_embd: int = Field(gt=0)
    n_hidden: int = Field(gt=0)
    gqa_query_heads: int = Field(gt=0)
    kv_heads: int = Field(gt=0)
    gqa_head_dim: int = Field(gt=0)
    cross_query_heads: int = Field(gt=0)
    cross_head_dim: int = Field(gt=0)
class ModelConfig(BaseModel, frozen=True):
    """Main configuration container for the Dia model architecture.

    Attributes:
        encoder: Configuration for the encoder component.
        decoder: Configuration for the decoder component.
        src_vocab_size: Size of the source (text) vocabulary.
        tgt_vocab_size: Size of the target (audio code) vocabulary.
        dropout: Dropout probability applied within the model.
        normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
        weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
        rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
        rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
    """

    encoder: EncoderConfig
    decoder: DecoderConfig
    src_vocab_size: int = Field(default=128, gt=0)
    tgt_vocab_size: int = Field(default=1028, gt=0)
    dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
    normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
    weight_dtype: str = Field(default="float32", description="Weight precision")
    # Fix: the two RoPE fields previously shared the same copy-pasted
    # description ("Timescale For global Attention"); the schema metadata now
    # matches the documented min/max semantics.
    rope_min_timescale: int = Field(
        default=1, description="Minimum RoPE timescale for global attention"
    )
    rope_max_timescale: int = Field(
        default=10_000, description="Maximum RoPE timescale for global attention"
    )
class TrainingConfig(BaseModel, frozen=True):
    """Training process configuration.

    Currently empty; retained only so serialized configs containing a
    "training" section still validate (see the TODO on ``DiaConfig.training``).
    """

    pass
class DiaConfig(BaseModel, frozen=True):
    """Master configuration for the Dia model.

    Combines all sub-configurations into a single validated object.

    Attributes:
        version: Configuration version string.
        model: Model architecture configuration.
        training: Deprecated placeholder, kept for backwards compatibility.
        data: Data loading and processing configuration.
    """

    version: str = Field(default="1.0")
    model: ModelConfig
    # TODO: remove training. this is just for backwards-compatability
    training: TrainingConfig
    data: DataConfig

    def save(self, path: str) -> None:
        """Save the current configuration instance to a JSON file.

        Creates the parent directory if it does not already exist.

        Args:
            path: The target file path to save the configuration.
        """
        # os.path.dirname() returns "" for a bare filename, and
        # os.makedirs("") raises FileNotFoundError — only create directories
        # when the path actually contains one.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        config_json = self.model_dump_json(indent=2)
        with open(path, "w") as f:
            f.write(config_json)

    # Fix: @classmethod was missing, so DiaConfig.load(path) bound `path`
    # to `cls` and failed; this is an alternate constructor.
    @classmethod
    def load(cls, path: str) -> "DiaConfig | None":
        """Load and validate a Dia configuration from a JSON file.

        Args:
            path: The path to the configuration file.

        Returns:
            A validated DiaConfig instance if the file exists and is valid,
            otherwise None if the file is not found.

        Raises:
            pydantic.ValidationError: If the JSON content fails validation
                against the DiaConfig schema.
        """
        # Keep the try body minimal: only the file read can plausibly raise
        # FileNotFoundError; validation errors propagate to the caller.
        try:
            with open(path, "r") as f:
                content = f.read()
        except FileNotFoundError:
            return None
        return cls.model_validate_json(content)