from pathlib import Path
import torch, os
from tqdm import tqdm
import pickle
import argparse
import logging, datetime
import torch.distributed as dist
from config import MyParser
from steps import trainer
from copy_codebase import copy_codebase
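
# Read this process's local rank, global rank, and world size from the
# environment variables set by the distributed launcher (e.g. torchrun).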
def world_info_from_env():
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    return local_rank, global_rank, world_size

if __name__ == "__main__":
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    torch.cuda.empty_cache()
    args = MyParser().parse_args()
    exp_dir = Path(args.exp_dir)
    exp_dir.mkdir(exist_ok=True, parents=True)
    logging.info(f"exp_dir: {str(exp_dir)}")
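    # Resume path: restore the arguments saved with the original run from
    # args.pkl, letting any newly passed values override the stored ones.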
    if args.resume and (os.path.exists("%s/bundle.pth" % args.exp_dir) or os.path.exists("%s/bundle_prev.pth" % args.exp_dir)):
        if not os.path.exists("%s/bundle.pth" % args.exp_dir):
            os.system(f"cp {args.exp_dir}/bundle_prev.pth {args.exp_dir}/bundle.pth")
        resume = args.resume
        assert bool(args.exp_dir)
        with open("%s/args.pkl" % args.exp_dir, "rb") as f:
            old_args = pickle.load(f)
        new_args = vars(args)
        old_args = vars(old_args)
        for key in new_args:
            if key not in old_args or old_args[key] != new_args[key]:
                old_args[key] = new_args[key]
        args = argparse.Namespace(**old_args)
        args.resume = resume
    else:
        args.resume = False
        with open("%s/args.pkl" % args.exp_dir, "wb") as f:
            pickle.dump(args, f)
    # make timeout longer (for generation)
    timeout = datetime.timedelta(seconds=7200)  # 2 hours
    if args.multinodes:
        _local_rank, _, _ = world_info_from_env()
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, timeout=timeout)
    else:
        dist.init_process_group(backend='nccl', init_method='env://', timeout=timeout)
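    # Log to Weights & Biases in offline mode when requested.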
    if args.local_wandb:
        os.environ["WANDB_MODE"] = "offline"
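    # Per-process rank bookkeeping and GPU selection.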
    rank = dist.get_rank()
    if rank == 0:
        logging.info(args)
        logging.info(f"exp_dir: {str(exp_dir)}")
    world_size = dist.get_world_size()
    local_rank = int(_local_rank) if args.multinodes else rank
    num_devices = torch.cuda.device_count()
    logging.info(f"{local_rank=}, {rank=}, {world_size=}, {type(local_rank)=}, {type(rank)=}, {type(world_size)=}")
    for device_idx in range(num_devices):
        device_name = torch.cuda.get_device_name(device_idx)
        logging.info(f"Device {device_idx}: {device_name}")
    torch.cuda.set_device(local_rank)
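    # On rank 0, snapshot the codebase into the experiment directory for reproducibility.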
    if rank == 0:
        user_dir = os.path.expanduser("~")
        codebase_name = "VoiceStar"
        now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        copy_codebase(os.path.join(user_dir, codebase_name), os.path.join(exp_dir, f"{codebase_name}_{now}"), max_size_mb=5, gitignore_path=os.path.join(user_dir, codebase_name, ".gitignore"))
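    # Build the trainer and run the training loop.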
    my_trainer = trainer.Trainer(args, world_size, rank, local_rank)
    my_trainer.train()
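    # Example single-node launch (a sketch only; the script name and exact flag
    # spellings are assumptions based on the attributes accessed above and on
    # whatever MyParser defines):
    #   torchrun --nproc_per_node=4 main.py --exp_dir /path/to/exp --resume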