# STAR/fairseq/model_parallel/megatron_trainer.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a network across multiple GPUs.
"""
from fairseq.dataclass.configs import FairseqConfig
from fairseq.distributed import utils as distributed_utils
from fairseq.trainer import Trainer
# The Megatron-LM model-parallel utilities (mpu) live in an optional git
# submodule; record whether they are available so a helpful error can be
# raised only when MegatronTrainer is actually constructed.
try:
    from fairseq.model_parallel.megatron.mpu import (
        get_data_parallel_rank,
        get_data_parallel_world_size,
        get_model_parallel_src_rank,
        get_cuda_rng_tracker,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


class MegatronTrainer(Trainer):
    """Main class for model parallel with data parallel training."""

    def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        super().__init__(cfg, task, model, criterion, **kwargs)

    def clip_grad_norm(self, clip_norm):
        def _aggregate_model_parallel_grad_norm(total_norm):
            # Each model-parallel rank sees only its shard of the gradients,
            # so combine the shard norms across the model-parallel group:
            # square the local norm, all-reduce (sum), then take the root.
            total_norm = total_norm**2
            distributed_utils.all_reduce(
                total_norm, group=distributed_utils.get_model_parallel_group()
            )
            total_norm = total_norm**0.5
            return total_norm

        return self.optimizer.clip_grad_norm(
            clip_norm,
            aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
        )
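
    # Equivalently, writing n_r for the shard norm on model-parallel rank r,
    # the aggregated value is sqrt(sum_r n_r**2), i.e. the L2 norm the full
    # (unsharded) gradient would have on a single device.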

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        # Also persist the per-rank CUDA RNG tracker states so that random
        # ops (e.g. dropout) resume identically across model-parallel ranks.
        extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states()
        super().save_checkpoint(filename, extra_state)

    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """Load all training state from a checkpoint file."""
        extra_state = super().load_checkpoint(
            filename,
            reset_optimizer=reset_optimizer,
            reset_lr_scheduler=reset_lr_scheduler,
            optimizer_overrides=optimizer_overrides,
            reset_meters=reset_meters,
        )
        # Restore the RNG tracker states saved by save_checkpoint(), when
        # present, so stateful random ops continue from the saved point.
        if extra_state is not None and "rng_tracker_states" in extra_state:
            get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"])
        return extra_state
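

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes fairseq's distributed/model-parallel state has already been
# initialized by the usual training entry point and that `cfg` is a populated
# FairseqConfig; the builder calls below mirror fairseq's task API.
#
#     from fairseq import tasks
#
#     task = tasks.setup_task(cfg.task)
#     model = task.build_model(cfg.model)
#     criterion = task.build_criterion(cfg.criterion)
#     trainer = MegatronTrainer(cfg, task, model, criterion)
#     trainer.save_checkpoint("checkpoint_last.pt", extra_state={"epoch": 1})
# ---------------------------------------------------------------------------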