# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Part of the code is from
# `https://github.com/facebookresearch/vissl/blob/main/vissl/utils/distributed_utils.py` and
# `https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/generic/distributed_util.py`
# Modified by Yue Zhao
# The original code is under MIT License

from typing import Tuple

import torch
import torch.distributed as dist


def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This helper function converts to the correct
    device and returns the tensor + original device.
    """
    orig_device = "cpu" if not tensor.is_cuda else "gpu"
    if (
        torch.distributed.is_available()
        and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
        and not tensor.is_cuda
    ):
        tensor = tensor.cuda()
    return (tensor, orig_device)


def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This converts the tensor back to the original device.
    """
    if tensor.is_cuda and orig_device == "cpu":
        tensor = tensor.cpu()
    return tensor
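

# A minimal round-trip sketch for the two helpers above (hedged: it assumes an
# NCCL process group is already initialized and a GPU is available; `t` is a
# hypothetical tensor, not part of this module):
#
#   t = torch.randn(4)                                 # starts on CPU
#   t, orig_device = convert_to_distributed_tensor(t)  # moved to GPU under NCCL
#   dist.all_reduce(t)                                 # NCCL op is now legal
#   t = convert_to_normal_tensor(t, orig_device)       # back on CPU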


def is_distributed_training_run() -> bool:
    """
    Return True if this process belongs to an initialized distributed
    run with a world size greater than one.
    """
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and (torch.distributed.get_world_size() > 1)
    )
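

# For reference, a typical setup under which is_distributed_training_run()
# returns True (hedged sketch; env:// initialization reads RANK/WORLD_SIZE
# from the environment and is illustrative, not part of this module):
#
#   dist.init_process_group(backend="nccl", init_method="env://")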


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    this implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Each rank receives gradients for all gathered outputs; sum them
        # across ranks and return the slice that matches this rank's input.
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
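

# Why not plain dist.all_gather? The tensors it returns carry no grad_fn, so
# gradients stop at the gather. GatherLayer re-routes them: in backward, every
# rank all-reduces the incoming gradients and keeps its own slice, so each
# local input receives gradient contributions from every rank's loss.
# Hedged usage sketch (assumes an initialized process group; `feats` is a
# hypothetical per-rank feature tensor):
#
#   gathered = GatherLayer.apply(feats)  # tuple of world_size tensors
#   all_feats = torch.cat(gathered, 0)   # differentiable w.r.t. local feats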


def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """
    Similar to classy_vision.generic.distributed_util.gather_from_all
    except that it does not cut the gradients.
    """
    if tensor.ndim == 0:
        # 0-dim tensors cannot be gathered, so unsqueeze to 1-dim first.
        tensor = tensor.unsqueeze(0)

    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        gathered_tensors = GatherLayer.apply(tensor)
        gathered_tensors = [
            convert_to_normal_tensor(_tensor, orig_device)
            for _tensor in gathered_tensors
        ]
    else:
        gathered_tensors = [tensor]
    gathered_tensor = torch.cat(gathered_tensors, 0)
    return gathered_tensor
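

if __name__ == "__main__":
    # Single-process smoke test (a hedged sketch, not part of the original
    # module): with no process group initialized, gather_from_all should act
    # as an identity, apart from unsqueezing 0-dim inputs.
    x = torch.randn(3, 2)
    assert torch.equal(gather_from_all(x), x)

    s = torch.tensor(1.0)  # 0-dim scalar
    assert gather_from_all(s).shape == (1,)
    print("single-process gather_from_all checks passed")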