import os
import torch
import torch.distributed as dist
import logging

logger = logging.getLogger(__name__)
def setup_for_distributed(is_master):
    """Silence warnings and logging on non-master ranks.

    On worker ranks, `warnings.warn` is replaced by a wrapper that only emits
    when called with `force=True`, and the `logging` module is disabled.
    """
    import warnings

    builtin_warn = warnings.warn

    def warn(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_warn(*args, **kwargs)

    # Log warnings only once
    warnings.warn = warn
    warnings.simplefilter("once", UserWarning)

    if not is_master:
        # Disables logging calls of severity CRITICAL and below on worker ranks.
        logging.disable()
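# Illustrative sketch (not part of the original module): the effect of the patch
# above on a worker rank. `force=True` is the escape hatch defined in `warn()`.
#
#   setup_for_distributed(is_master=False)
#   warnings.warn("dropped on workers")            # swallowed by the wrapper
#   warnings.warn("still shown", force=True)       # emitted despite is_master=False
#   logging.getLogger(__name__).info("dropped")    # logging is disabled on workers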
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0
def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)
def is_port_in_use(port):
    import socket
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(('localhost', port)) == 0
def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # job started by torch.distributed.launch / torchrun
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # local rank on the current node / global rank
        local_rank = int(os.environ['SLURM_LOCALID'])
        global_rank = int(os.environ['SLURM_PROCID'])
        # world size = number of nodes * tasks (GPUs) per node.
        # Note: only the first character of SLURM_TASKS_PER_NODE is read, so this
        # assumes fewer than 10 tasks per node and a homogeneous allocation.
        world_size = int(os.environ["SLURM_NNODES"]) * \
            int(os.environ["SLURM_TASKS_PER_NODE"][0])
        print(world_size)

        args.rank = global_rank
        args.gpu = local_rank
        args.world_size = world_size
    else:
        logger.info('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'

    if "tcp" in args.dist_url:  # on SLURM, multiple programs may run on a single node
        # probe for a free port, stepping by 10 from the requested one
        dist_port = int(args.dist_url.split(":")[-1])
        while is_port_in_use(dist_port):
            dist_port += 10
        args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)])
    print(args.dist_url)

    logger.info('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url))
    if "SLURM_JOB_ID" in os.environ:
        logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}")
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
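# Usage sketch (hypothetical entry point, not part of the original module): a training
# script would typically build an argparse namespace that carries `dist_url`, call
# `init_distributed_mode(args)` once per process, and then rely on the attributes the
# function fills in (`distributed`, `rank`, `gpu`, `world_size`, `dist_backend`).
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--dist_url", default="env://")   # or e.g. "tcp://127.0.0.1:29500"
#   args = parser.parse_args()
#   init_distributed_mode(args)
#   if args.distributed:
#       model = torch.nn.parallel.DistributedDataParallel(
#           model, device_ids=[args.gpu])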
# Copyright (c) Facebook, Inc. and its affiliates.
# copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py
class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    this implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Sum the gradients coming from every rank, then return the slice that
        # corresponds to this rank's input.
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
# copied from megavlt
def gather_tensor_along_batch_with_backward(tensor, dim=0):
    """Gather `tensor` from all ranks along `dim`, keeping gradients (see GatherLayer)."""
    world_size = get_world_size()

    if world_size < 2:
        return tensor

    tensor_list = GatherLayer.apply(tensor)
    tensor_list = torch.cat(tensor_list, dim=dim)
    return tensor_list
def gather_tensor_along_batch(tensor, dim=0):
    """
    Performs an all_gather operation on the provided tensor.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = get_world_size()

    if world_size < 2:
        return tensor

    with torch.no_grad():
        tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_list, tensor)
        tensor_list = torch.cat(tensor_list, dim=dim)
    return tensor_list
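# Usage sketch (hypothetical): when a loss is computed over the global batch (e.g. a
# contrastive loss), gather the per-GPU features so every rank sees all of them.
# `gather_tensor_along_batch_with_backward` keeps the autograd graph, so gradients
# still flow to the local shard; `gather_tensor_along_batch` is detached and is only
# appropriate for labels, targets, or statistics that need no gradient.
#
#   feats = model(images)                                        # (B, D) on each rank
#   all_feats = gather_tensor_along_batch_with_backward(feats)   # (B * world_size, D)
#   all_labels = gather_tensor_along_batch(labels)               # gathered without grad
#   loss = contrastive_loss(feats, all_feats, all_labels)        # hypothetical loss fn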