Spaces: Running on Zero
| from torch import nn | |
| from transformers.modeling_utils import PreTrainedModel | |
| from .configuration_higgs_audio import HiggsAudioConfig | |
class HiggsAudioPreTrainedModel(PreTrainedModel):
    """Common base class for Higgs Audio models.

    Binds the model family to ``HiggsAudioConfig``, declares the class-level
    flags that transformers' ``PreTrainedModel`` inspects, and supplies the
    per-module weight initializer.
    """

    config_class = HiggsAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the parameters of a single submodule in place.

        Weights of ``nn.Linear`` / ``nn.Conv1d`` / ``nn.Embedding`` modules are
        drawn from N(0, std**2); Linear/Conv1d biases (when present) and the
        Embedding padding row (when ``padding_idx`` is set) are zeroed. Any
        other module type is left untouched.

        NOTE(review): the caller is not visible here; presumably this is
        invoked once per submodule by PreTrainedModel's init machinery.

        Args:
            module: The submodule to initialize.

        Raises:
            AttributeError: if the config has no ``init_std`` and also no
                ``audio_encoder_config.init_std`` to fall back to.
        """
        cfg = self.config
        # Prefer a top-level init_std; otherwise reuse the audio encoder's.
        if hasattr(cfg, "init_std"):
            std = cfg.init_std
        else:
            std = cfg.audio_encoder_config.init_std

        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            table = module.weight.data
            table.normal_(mean=0.0, std=std)
            pad_row = module.padding_idx
            # Keep the padding embedding at exactly zero after random init.
            if pad_row is not None:
                table[pad_row].zero_()