Spaces:
Configuration error
Configuration error
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| import torch | |
| from models.tts.fastspeech2.fs2_trainer import FastSpeech2Trainer | |
| from models.tts.vits.vits_trainer import VITSTrainer | |
| from models.tts.valle.valle_trainer import VALLETrainer | |
| from models.tts.naturalspeech2.ns2_trainer import NS2Trainer | |
| from models.tts.valle_v2.valle_ar_trainer import ValleARTrainer as VALLE_V2_AR | |
| from models.tts.valle_v2.valle_nar_trainer import ValleNARTrainer as VALLE_V2_NAR | |
| from models.tts.jets.jets_trainer import JetsTrainer | |
| from utils.util import load_config | |
def build_trainer(args, cfg):
    """Instantiate the trainer class registered for ``cfg.model_type``.

    Args:
        args: Parsed command-line arguments; forwarded unchanged to the
            trainer constructor.
        cfg: Loaded experiment config. ``cfg.model_type`` selects the trainer.

    Returns:
        A new trainer built as ``trainer_class(args, cfg)``.

    Raises:
        KeyError: If ``cfg.model_type`` is not one of the supported model
            types. The message lists the accepted values.
    """
    supported_trainer = {
        "FastSpeech2": FastSpeech2Trainer,
        "VITS": VITSTrainer,
        "VALLE": VALLETrainer,
        "NaturalSpeech2": NS2Trainer,
        "VALLE_V2_AR": VALLE_V2_AR,
        "VALLE_V2_NAR": VALLE_V2_NAR,
        "Jets": JetsTrainer,
    }

    model_type = cfg.model_type
    try:
        trainer_class = supported_trainer[model_type]
    except KeyError:
        # Keep KeyError for backward compatibility, but make the failure
        # actionable instead of just echoing the bad key.
        raise KeyError(
            f"Unsupported model_type {model_type!r}; "
            f"expected one of {sorted(supported_trainer)}"
        ) from None
    return trainer_class(args, cfg)
def cuda_relevant(deterministic=False):
    """Configure global CUDA / cuDNN backend flags for training.

    Releases cached CUDA memory, enables TF32 for matmul and cuDNN, and
    toggles reproducibility settings.

    Args:
        deterministic: If True, force deterministic cuDNN kernels and
            deterministic PyTorch algorithms and disable cuDNN autotuning.
            If False, enable cuDNN benchmark autotuning and allow
            nondeterministic algorithms.
    """
    import os

    torch.cuda.empty_cache()

    # TF32 on Ampere and above
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True

    # Deterministic mode and benchmark autotuning are mutually exclusive:
    # autotuning may pick different kernels from run to run.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic

    if deterministic:
        # With deterministic algorithms enabled, PyTorch raises on cuBLAS ops
        # (CUDA >= 10.2) unless this is set. Respect a user-provided value.
        # The variable only takes effect if set before cuBLAS is first used.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.use_deterministic_algorithms(deterministic)
def _build_arg_parser():
    """Return the command-line parser for this training entry point."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        help="json files for configurations.",
        required=True,
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1234,
        help="random seed",
        required=False,
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        help="A specific name to note the experiment",
        required=True,
    )
    parser.add_argument(
        "--resume", action="store_true", help="Resume from a saved checkpoint."
    )
    parser.add_argument(
        "--test", action="store_true", default=False, help="Test the model"
    )
    parser.add_argument(
        "--log_level", default="warning", help="logging level (debug, info, warning)"
    )
    parser.add_argument(
        "--resume_type",
        type=str,
        default="resume",
        help="Resume training or finetuning.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default=None,
        help="Checkpoint for resume training or finetuning.",
    )
    parser.add_argument(
        "--resume_from_ckpt_path",
        type=str,
        default="",
        help="Checkpoint for resume training or finetuning.",
    )
    return parser


def _append_augmented_datasets(cfg):
    """Extend ``cfg.dataset`` with the augmented variants enabled in the config.

    For every name in ``cfg.preprocess.data_augment``, one entry
    ``"{name}_{suffix}"`` is appended per enabled ``use_*`` flag, in the order
    pitch shift, formant shift, equalizer, time stretch. Does nothing when
    ``cfg.preprocess.data_augment`` is missing, not a list, or empty. A missing
    ``use_*`` flag raises ``AttributeError`` rather than being silently skipped.
    """
    preprocess = getattr(cfg, "preprocess", None)
    if preprocess is None:
        return
    augment = getattr(preprocess, "data_augment", None)
    if not isinstance(augment, list) or not augment:
        return

    # (config flag, dataset-name suffix), in the order variants are appended.
    variants = (
        ("use_pitch_shift", "pitch_shift"),
        ("use_formant_shift", "formant_shift"),
        ("use_equalizer", "equalizer"),
        ("use_time_stretch", "time_stretch"),
    )
    new_datasets = [
        f"{dataset}_{suffix}"
        for dataset in augment
        for flag, suffix in variants
        if getattr(preprocess, flag)
    ]
    cfg.dataset.extend(new_datasets)


def main():
    """Parse CLI arguments, load the config, then run training or testing."""
    args = _build_arg_parser().parse_args()
    cfg = load_config(args.config)

    # Data Augmentation
    _append_augmented_datasets(cfg)

    print("experiment name: ", args.exp_name)

    # CUDA settings
    cuda_relevant()

    # Build trainer
    print(f"Building {cfg.model_type} trainer")
    trainer = build_trainer(args, cfg)

    if args.test:
        print(f"Start testing {cfg.model_type} model")
        trainer.test_loop()
    else:
        print(f"Start training {cfg.model_type} model")
        trainer.train_loop()


if __name__ == "__main__":
    main()