# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from typing import Optional

import torch
from fairseq.dataclass.configs import DistributedTrainingConfig
from fairseq.distributed import utils as dist_utils

try:
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    has_FSDP = True
except ImportError:
    FSDP = torch.nn.Module
    has_FSDP = False

class FullyShardedDataParallel(FSDP):
    """
    A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some
    fairseq-specific checkpoint saving/loading logic.

    Args:
        use_sharded_state (bool): if True, then ``state_dict`` will return
            ``FSDP.local_state_dict`` and ``load_state_dict`` will call
            ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will
            return the full model weights on data parallel rank 0 (empty on
            other ranks) and ``load_state_dict`` will broadcast model weights
            from rank 0 to other ranks.
    """

    def __init__(self, *args, use_sharded_state: bool = False, **kwargs):
        if not has_FSDP:
            raise ImportError(
                "Cannot find FullyShardedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )
        super().__init__(*args, **kwargs)
        self.use_sharded_state = use_sharded_state

    @property
    def unwrapped_module(self) -> torch.nn.Module:
        if self.flatten_parameters:
            return self.module.module
        else:
            return self.module
    def state_dict(self, destination=None, prefix="", keep_vars=False):
        if self.use_sharded_state:
            return super().local_state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        else:
            if self.rank == 0:
                return super().state_dict(
                    destination=destination, prefix=prefix, keep_vars=keep_vars
                )
            else:
                # We must call state_dict() due to use of communication
                # primitives. But we don't use the result.
                super().state_dict()
                return destination or {}

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        if self.use_sharded_state:
            return super().load_local_state_dict(state_dict, strict=strict)
        else:
            state_dict = dist_utils.broadcast_object(
                state_dict, src_rank=0, group=self.process_group
            )
            return super().load_state_dict(state_dict, strict=strict)
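
# A minimal usage sketch (not part of the original module), assuming ``model`` has
# already been wrapped with the FullyShardedDataParallel class above. It illustrates
# the docstring: with use_sharded_state=False, only data parallel rank 0 receives the
# full state dict, so only that rank should write the checkpoint. The function name
# and checkpoint paths are illustrative.
def _example_save_checkpoint(model: "FullyShardedDataParallel", path: str) -> None:
    state = model.state_dict()
    if model.use_sharded_state:
        # each rank holds only its own shard, so write one file per rank
        torch.save(state, f"{path}-shard{model.rank}")
    elif model.rank == 0:
        # full weights were gathered onto rank 0; other ranks got an empty dict
        torch.save(state, path)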

class DummyProcessGroup:
    def __init__(self, rank: int, size: int):
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size

@contextlib.contextmanager
def fsdp_enable_wrap(cfg: DistributedTrainingConfig):
    try:
        from fairscale.nn import enable_wrap
    except ImportError:
        raise ImportError(
            "Cannot find FullyShardedDataParallel. "
            "Please install fairscale with: pip install fairscale"
        )
    if cfg.memory_efficient_fp16:
        assert cfg.fp16  # memory_efficient_fp16 should imply fp16
    group = dist_utils.get_data_parallel_group()
    if group is None and cfg.distributed_world_size == 1:
        group = DummyProcessGroup(rank=0, size=1)
    fsdp_config = {
        "process_group": group,
        "reshard_after_forward": not cfg.no_reshard_after_forward,
        "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16,
        "fp32_reduce_scatter": cfg.fp32_reduce_scatter,
        "flatten_parameters": not cfg.not_fsdp_flatten_parameters,
        "cpu_offload": cfg.cpu_offload,
        "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
        "bucket_cap_mb": cfg.bucket_cap_mb,
        "state_dict_device": torch.device("cpu"),  # reduce GPU mem usage
    }
    with enable_wrap(
        wrapper_cls=FullyShardedDataParallel,
        use_sharded_state=cfg.use_sharded_state,
        **fsdp_config,
    ):
        yield
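
# A minimal usage sketch (not part of the original module): because of the ``yield``
# above, fsdp_enable_wrap is meant to be used as a context manager, so that any
# fsdp_wrap() calls made while building the model pick up this FSDP configuration.
# ``cfg`` is assumed to be a populated DistributedTrainingConfig and ``build_model``
# is illustrative:
#
#     with fsdp_enable_wrap(cfg):
#         model = fsdp_wrap(build_model(cfg))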

def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
    """
    Helper to wrap layers/modules in FSDP. This falls back to a no-op if
    fairscale is not available.

    Args:
        module (nn.Module): module to (maybe) wrap
        min_num_params (int, Optional): minimum number of layer params to wrap
    """
    try:
        from fairscale.nn import wrap

        if min_num_params is not None:
            num_params = sum(p.numel() for p in module.parameters())
            if num_params >= min_num_params:
                return wrap(module, **kwargs)
            else:
                return module
        else:
            return wrap(module, **kwargs)
    except ImportError:
        return module
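
# A minimal usage sketch (not part of the original module), assuming it runs inside
# an fsdp_enable_wrap(cfg) block as in the earlier sketch. It illustrates the
# min_num_params threshold documented above: a layer is only wrapped in FSDP when it
# is large enough. The layer sizes and the 1e6 threshold are illustrative.
def _example_wrap_layers() -> torch.nn.Module:
    small = torch.nn.Linear(16, 16)  # 272 parameters, below the threshold
    large = torch.nn.Linear(4096, 4096)  # ~16.8M parameters, above the threshold
    return torch.nn.Sequential(
        fsdp_wrap(small, min_num_params=int(1e6)),  # returned unwrapped
        fsdp_wrap(large, min_num_params=int(1e6)),  # wrapped in FSDP (no-op if
        # fairscale is not installed)
    )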