# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import random
import time

import torch
import torch.distributed as dist

from maskrcnn_benchmark.utils.comm import get_world_size, synchronize, broadcast_data
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
from maskrcnn_benchmark.utils.ema import ModelEma


def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only the main process receives the accumulated sum, so only
            # divide by world_size on rank 0
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses
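

# Usage sketch (mirrors the caller pattern in do_train below): the reduced
# dict is only for logging; the backward pass uses the local, unreduced losses.
#
#   loss_dict = model(images, targets)
#   losses = sum(loss for loss in loss_dict.values())   # used for backward()
#   loss_dict_reduced = reduce_loss_dict(loss_dict)     # used for logging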


def do_train(
    cfg,
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
    rngs=None,
):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
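
    # optionally track an exponential moving average (EMA) of the model weights;
    # cfg.SOLVER.MODEL_EMA is the decay factor, and a value of 0 disables it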
    model_ema = None
    if cfg.SOLVER.MODEL_EMA > 0:
        model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA)

    start_training_time = time.time()
    end = time.time()
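
    # main loop; start_iter lets training resume from the iteration count
    # stored in the checkpoint arguments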
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        if any(len(target) < 1 for target in targets):
            logger.error(
                f"Iteration={iteration + 1} || Image Ids used for training {_} || "
                f"targets Length={[len(target) for target in targets]}"
            )
            continue
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device)
        targets = [target.to(device) for target in targets]
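
        # rngs is sampled once (only while it is still None) and broadcast so
        # every rank uses the same randomly chosen sub-network; mix_nums is
        # assumed to hold the number of candidate operations per mixed block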
        # synchronize rngs
        if rngs is None:
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                mix_nums = model.module.mix_nums
            else:
                mix_nums = model.mix_nums
            rngs = [random.randint(0, mix - 1) for mix in mix_nums]
            rngs = broadcast_data(rngs)
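
        # gradients are disabled for all parameters before the forward pass;
        # the sampled sub-network is presumably expected to re-enable the
        # weights it actually trains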
        for param in model.parameters():
            param.requires_grad = False

        loss_dict = model(images, targets, rngs)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        scheduler.step()
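
        # update the EMA copy after every optimizer step and store its state in
        # the checkpoint arguments so it is saved with each checkpoint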
        if model_ema is not None:
            model_ema.update(model)
            arguments["model_ema"] = model_ema.state_dict()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
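
        # log every 20 iterations and at the final iteration; memory is the
        # peak CUDA allocation in MiB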
        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
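
        # periodic checkpoint; at the final iteration the EMA weights (if any)
        # are loaded into the model before model_final is saved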
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            if model_ema is not None:
                model.load_state_dict(model_ema.state_dict())
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / max_iter
        )
    )