# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
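
"""Unit tests for the FairseqBMUF optimizer wrapper (block-wise model-update
filtering): one worker is spawned per rank and the tests check that the model
parameters stay synchronized across workers."""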

import argparse
import functools
import random
import unittest
from multiprocessing import Manager

import torch
import torch.nn as nn

from fairseq import optim
from fairseq.distributed import utils as distributed_utils
from omegaconf import OmegaConf
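

# Minimal single-layer model; its parameters are what the tests compare across
# workers after BMUF synchronization.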
class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        return output


def setup_model_loss_criterion(cfg, args, rank, is_cuda):
    """
    Set up the model, loss criterion, and optimizer based on the input args/cfg.
    """
    args.distributed_rank = rank
    cfg.distributed_training.distributed_rank = args.distributed_rank
    if cfg.distributed_training.distributed_world_size > 1:
        distributed_utils.distributed_init(cfg)
    torch.manual_seed(1)
    model = Model(args.input_size, args.nb_classes)
    loss_fn = nn.CrossEntropyLoss()
    if is_cuda:
        model = model.cuda()
        loss_fn = loss_fn.cuda()
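
    # Wrap the plain SGD optimizer in FairseqBMUF so that parameter updates are
    # periodically synchronized (averaged) across workers.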
    optimizer = optim.sgd.SGD(args, model.parameters())
    optimizer = optim.FairseqBMUF(
        cfg=cfg.bmuf,
        optimizer=optimizer,
    )

    return model, loss_fn, optimizer


def train_step(input, target, model, loss_fn, optimizer, **unused):
    """Do forward, backward and parameter update."""
    model.train()
    output = model(input)
    loss = loss_fn(output, target)
    optimizer.backward(loss)
    optimizer.step()
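

# Worker entry point: run `iterations` training steps on one rank and store the
# flattened model parameters in `shared_results` for comparison across ranks.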
def single_gpu_training(cfg, args, rank, iterations, shared_results):
    is_cuda = torch.cuda.is_available()
    if is_cuda:
        torch.cuda.set_device(rank)

    model, loss_fn, optimizer = setup_model_loss_criterion(cfg, args, rank, is_cuda)

    for _ in range(iterations):
        input = torch.randn(1, args.input_size)
        target = torch.empty(args.batch_size, dtype=torch.long).random_(args.nb_classes)
        if is_cuda:
            input = input.cuda()
            target = target.cuda()
        train_step(input, target, model, loss_fn, optimizer)
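
    # Flatten and concatenate all parameter tensors so that each rank's model
    # can be compared element-wise.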
    results = []
    for param in model.parameters():
        if len(results) == 0:
            results = param.flatten().cpu().data
        else:
            results = torch.cat((results, param.flatten().cpu().data), 0)

    shared_results[rank] = results
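

# Build both a legacy argparse Namespace and an OmegaConf config: the SGD
# optimizer is constructed from `args`, while FairseqBMUF and the distributed
# init consume `cfg`.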
def setup_args():
    args = argparse.Namespace()
    args.global_sync_iter = 20
    args.block_momentum = 0.875
    args.block_lr = 0.5
    args.input_size = 5
    args.nb_classes = 2
    args.batch_size = 1
    args.lr = [1e-3]
    args.momentum = 0
    args.weight_decay = 0
    args.warmup_iterations = 0
    args.use_nbm = True
    args.average_sync = True
    args.global_sync_iter = 1  # overrides the value set above; tests adjust it as needed
    args.model_parallel_size = 1
    args.distributed_backend = "gloo"

    args.distributed_world_size = 2
    port = random.randint(10000, 20000)
    args.distributed_init_method = "tcp://localhost:{port}".format(port=port)
    args.distributed_init_host = "localhost"
    args.distributed_port = port + 1
    args.local_world_size = args.distributed_world_size

    cfg = OmegaConf.create()
    cfg.optimization = OmegaConf.create()
    cfg.common = OmegaConf.create()
    cfg.distributed_training = OmegaConf.create()
    cfg.dataset = OmegaConf.create()
    cfg.bmuf = OmegaConf.create()
    cfg.optimizer = OmegaConf.create()

    # Mirror the relevant args into the structured config.
    cfg.bmuf.global_sync_iter = args.global_sync_iter
    cfg.bmuf.block_momentum = args.block_momentum
    cfg.bmuf.block_lr = args.block_lr
    cfg.dataset.batch_size = args.batch_size
    cfg.optimization.lr = args.lr
    cfg.optimizer.momentum = args.momentum
    cfg.optimizer.weight_decay = args.weight_decay
    cfg.bmuf.warmup_iterations = args.warmup_iterations
    cfg.bmuf.use_nbm = args.use_nbm
    cfg.bmuf.average_sync = args.average_sync
    cfg.common.model_parallel_size = args.model_parallel_size
    cfg.distributed_training.distributed_backend = args.distributed_backend
    cfg.distributed_training.distributed_world_size = args.distributed_world_size
    cfg.bmuf.distributed_world_size = args.distributed_world_size
    cfg.distributed_training.distributed_init_method = args.distributed_init_method
    cfg.distributed_training.distributed_port = args.distributed_port

    return cfg, args
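

# Each test spawns one worker per rank and checks that, after training with
# BMUF, the flattened model parameters agree across ranks.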
class TestBMUF(unittest.TestCase):
    def bmuf_process(self, cfg, args, iterations):
        # Shared dict used to collect each rank's flattened parameters.
        results = Manager().dict()
        torch.multiprocessing.spawn(
            fn=functools.partial(single_gpu_training, cfg, args),
            args=(iterations, results),
            nprocs=args.distributed_world_size,
            join=True,
        )
        return results

    def test_bmuf_sync(self):
        # Train the model for 1 iteration and do a BMUF sync without any warmup.
        cfg, args = setup_args()
        iterations = 1
        results = self.bmuf_process(cfg, args, iterations)
        # Make sure the params on both workers are the same.
        assert len(results) == 2
        self.assertAlmostEqual(results[0], results[1])

    def test_warmup_sync(self):
        # Train the model for 20 iterations and do a warmup sync without doing
        # a BMUF sync.
        cfg, args = setup_args()
        args.warmup_iterations = 20
        cfg.bmuf.warmup_iterations = args.warmup_iterations
        iterations = 20
        results = self.bmuf_process(cfg, args, iterations)
        # Make sure the params on both workers are the same.
        assert len(results) == 2
        self.assertAlmostEqual(results[0], results[1])

    def test_warmup_sync_bmuf_sync(self):
        # Train the model for 25 iterations: a warmup sync after 20 iterations
        # and a BMUF sync after 25 iterations.
        cfg, args = setup_args()
        args.warmup_iterations = 20
        args.global_sync_iter = 5
        cfg.bmuf.warmup_iterations = args.warmup_iterations
        cfg.bmuf.global_sync_iter = args.global_sync_iter
        iterations = 25
        results = self.bmuf_process(cfg, args, iterations)
        # Make sure the params on both workers are the same.
        assert len(results) == 2
        self.assertAlmostEqual(results[0], results[1])

    def test_single_gpu_bmuf(self):
        # Train the model on a single worker (world size 1) with 5 warmup
        # iterations and 20 training iterations.
        cfg, args = setup_args()
        args.distributed_world_size = 1
        args.warmup_iterations = 5
        cfg.distributed_training.distributed_world_size = args.distributed_world_size
        cfg.bmuf.distributed_world_size = args.distributed_world_size
        cfg.bmuf.warmup_iterations = args.warmup_iterations
        iterations = 20
        results = self.bmuf_process(cfg, args, iterations)
        assert len(results) == 1

    def assertAlmostEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-4)


if __name__ == "__main__":
    unittest.main()