# NOTE(review): removed paste-artifact lines ("Spaces:" / "Runtime error") that were not valid Python.
import torch
from torch.nn import Parameter

from ..models.factory import create_model_from_config
def _make_ema_copy(model_config, model):
    """Build a fresh model from *model_config* and copy *model*'s weights into it.

    Used to seed an EMA (exponential moving average) shadow copy so it starts
    from the same weights as the trained model.
    """
    ema_copy = create_model_from_config(model_config)

    # Copy each weight into the EMA copy
    for name, param in model.state_dict().items():
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        ema_copy.state_dict()[name].copy_(param)

    return ema_copy

def create_training_wrapper_from_config(model_config, model):
    """Wrap *model* in the training wrapper matching ``model_config['model_type']``.

    Args:
        model_config: dict with at least ``model_type`` and a ``training`` sub-dict
            (hyperparameters, EMA/teacher options, demo config, etc.).
        model: the instantiated model to train.

    Returns:
        A ``pytorch_lightning``-style training wrapper for the given model type.

    Raises:
        AssertionError: if ``model_type`` or ``training`` is missing.
        ValueError: if a teacher model is configured without a checkpoint, or an
            unknown prior type is requested.
        NotImplementedError: for an unknown ``model_type``.
    """
    model_type = model_config.get('model_type', None)
    assert model_type is not None, 'model_type must be specified in model config'

    training_config = model_config.get('training', None)
    assert training_config is not None, 'training config must be specified in model config'

    if model_type == 'autoencoder':
        from .autoencoders import AutoencoderTrainingWrapper

        use_ema = training_config.get("use_ema", False)

        ema_copy = None
        if use_ema:
            # NOTE(review): the original code instantiated the model TWICE here,
            # with a comment that a single instantiation broke. The extra,
            # discarded instantiation is preserved deliberately — confirm before
            # removing.
            create_model_from_config(model_config)
            ema_copy = _make_ema_copy(model_config, model)

        latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0)

        # Optional distillation teacher: frozen, eval-mode, must come with a checkpoint.
        teacher_model = training_config.get("teacher_model", None)
        if teacher_model is not None:
            teacher_model = create_model_from_config(teacher_model)
            teacher_model = teacher_model.eval().requires_grad_(False)

            teacher_model_ckpt = training_config.get("teacher_model_ckpt", None)
            if teacher_model_ckpt is not None:
                teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"])
            else:
                raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified")

        return AutoencoderTrainingWrapper(
            model,
            lr=training_config["learning_rate"],
            warmup_steps=training_config.get("warmup_steps", 0),
            encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False),
            sample_rate=model_config["sample_rate"],
            loss_config=training_config.get("loss_configs", None),
            optimizer_configs=training_config.get("optimizer_configs", None),
            use_ema=use_ema,
            ema_copy=ema_copy if use_ema else None,
            force_input_mono=training_config.get("force_input_mono", False),
            latent_mask_ratio=latent_mask_ratio,
            teacher_model=teacher_model
        )
    elif model_type == 'diffusion_uncond':
        from .diffusion import DiffusionUncondTrainingWrapper
        return DiffusionUncondTrainingWrapper(
            model,
            lr=training_config["learning_rate"],
            pre_encoded=training_config.get("pre_encoded", False),
        )
    elif model_type == 'diffusion_infill':
        from .diffusion import DiffusionInfillTrainingWrapper
        return DiffusionInfillTrainingWrapper(
            model,
            lr=training_config.get("learning_rate", None),
            optimizer_configs=training_config.get("optimizer_configs", None),
            pre_encoded=training_config.get("pre_encoded", False),
            frac_lengths_mask=training_config.get("frac_lengths_mask", (0.7, 1.)),
            min_span_len=training_config.get("min_span_len", 10),
            timestep_sampler=training_config.get("timestep_sampler", "uniform"),
            ctx_drop=training_config.get("ctx_drop", 0.1),
            r_drop=training_config.get("r_drop", 0.0)
        )
    elif model_type == 'diffusion_cond' or model_type == 'mm_diffusion_cond':
        from .diffusion import DiffusionCondTrainingWrapper
        return DiffusionCondTrainingWrapper(
            model,
            lr=training_config.get("learning_rate", None),
            mask_padding=training_config.get("mask_padding", False),
            mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0),
            use_ema=training_config.get("use_ema", True),
            log_loss_info=training_config.get("log_loss_info", False),
            optimizer_configs=training_config.get("optimizer_configs", None),
            pre_encoded=training_config.get("pre_encoded", False),
            diffusion_objective=training_config.get("diffusion_objective", "v"),
            cfg_dropout_prob=training_config.get("cfg_dropout_prob", 0.1),
            timestep_sampler=training_config.get("timestep_sampler", "uniform"),
            max_mask_segments=training_config.get("max_mask_segments", 0)
        )
    elif model_type == 'diffusion_prior':
        from .diffusion import DiffusionPriorTrainingWrapper
        from ..models.diffusion_prior import PriorType

        ema_copy = _make_ema_copy(model_config, model)

        prior_type = training_config.get("prior_type", "mono_stereo")

        if prior_type == "mono_stereo":
            prior_type_enum = PriorType.MonoToStereo
        else:
            raise ValueError(f"Unknown prior type: {prior_type}")

        return DiffusionPriorTrainingWrapper(
            model,
            lr=training_config["learning_rate"],
            ema_copy=ema_copy,
            prior_type=prior_type_enum,
            log_loss_info=training_config.get("log_loss_info", False),
            use_reconstruction_loss=training_config.get("use_reconstruction_loss", False),
        )
    elif model_type == 'diffusion_autoencoder':
        from .diffusion import DiffusionAutoencoderTrainingWrapper

        ema_copy = _make_ema_copy(model_config, model)

        return DiffusionAutoencoderTrainingWrapper(
            model,
            ema_copy=ema_copy,
            lr=training_config["learning_rate"],
            use_reconstruction_loss=training_config.get("use_reconstruction_loss", False)
        )
    elif model_type == 'lm':
        from .lm import AudioLanguageModelTrainingWrapper

        ema_copy = _make_ema_copy(model_config, model)

        return AudioLanguageModelTrainingWrapper(
            model,
            ema_copy=ema_copy,
            lr=training_config.get("learning_rate", None),
            use_ema=training_config.get("use_ema", False),
            optimizer_configs=training_config.get("optimizer_configs", None),
            pre_encoded=training_config.get("pre_encoded", False),
        )
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')
def create_demo_callback_from_config(model_config, **kwargs):
    """Build the demo/preview callback matching ``model_config['model_type']``.

    Args:
        model_config: dict with ``model_type``, a ``training`` sub-dict (whose
            optional ``demo`` sub-dict holds demo settings), and for most model
            types ``sample_size`` / ``sample_rate``.
        **kwargs: forwarded to the callback constructor for the model types
            that accept extra arguments (e.g. a demo dataloader).

    Returns:
        A demo callback instance for the given model type.

    Raises:
        AssertionError: if ``model_type`` or ``training`` is missing.
        NotImplementedError: for an unknown ``model_type``.
    """
    model_type = model_config.get('model_type', None)
    assert model_type is not None, 'model_type must be specified in model config'

    training_config = model_config.get('training', None)
    assert training_config is not None, 'training config must be specified in model config'

    demo_config = training_config.get("demo", {})
    # Every callback shares the same demo cadence default.
    demo_every = demo_config.get("demo_every", 2000)

    if model_type == 'autoencoder':
        from .autoencoders import AutoencoderDemoCallback
        return AutoencoderDemoCallback(
            demo_every=demo_every,
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            **kwargs
        )

    if model_type == 'diffusion_uncond':
        from .diffusion import DiffusionUncondDemoCallback
        return DiffusionUncondDemoCallback(
            demo_every=demo_every,
            demo_steps=demo_config.get("demo_steps", 250),
            sample_rate=model_config["sample_rate"]
        )

    if model_type == 'diffusion_infill':
        from .diffusion import DiffusionInfillDemoCallback
        return DiffusionInfillDemoCallback(
            demo_every=demo_every,
            demo_steps=demo_config.get("demo_steps", 250),
            sample_rate=model_config["sample_rate"],
            **kwargs
        )

    if model_type == "diffusion_autoencoder":
        from .diffusion import DiffusionAutoencoderDemoCallback
        return DiffusionAutoencoderDemoCallback(
            demo_every=demo_every,
            demo_steps=demo_config.get("demo_steps", 250),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            **kwargs
        )

    if model_type == "diffusion_prior":
        from .diffusion import DiffusionPriorDemoCallback
        return DiffusionPriorDemoCallback(
            demo_every=demo_every,
            demo_steps=demo_config.get("demo_steps", 250),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            **kwargs
        )

    if model_type in ("diffusion_cond", "mm_diffusion_cond"):
        from .diffusion import DiffusionCondDemoCallback
        # NOTE: num_demos and demo_cfg_scales are required keys here.
        return DiffusionCondDemoCallback(
            demo_every=demo_every,
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            demo_steps=demo_config.get("demo_steps", 250),
            num_demos=demo_config["num_demos"],
            demo_cfg_scales=demo_config["demo_cfg_scales"],
            demo_conditioning=demo_config.get("demo_cond", {}),
            demo_cond_from_batch=demo_config.get("demo_cond_from_batch", False),
            display_audio_cond=demo_config.get("display_audio_cond", False),
        )

    if model_type == "diffusion_cond_inpaint":
        from .diffusion import DiffusionCondInpaintDemoCallback
        return DiffusionCondInpaintDemoCallback(
            demo_every=demo_every,
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            demo_steps=demo_config.get("demo_steps", 250),
            demo_cfg_scales=demo_config["demo_cfg_scales"],
            **kwargs
        )

    if model_type == "lm":
        from .lm import AudioLanguageModelDemoCallback
        return AudioLanguageModelDemoCallback(
            demo_every=demo_every,
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]),
            demo_conditioning=demo_config.get("demo_cond", None),
            num_demos=demo_config.get("num_demos", 8),
            **kwargs
        )

    raise NotImplementedError(f'Unknown model type: {model_type}')