# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from typing import Any, Dict, List, Set

from detectron2.solver.build import maybe_add_gradient_clipping

def build_optimizer(cfg, model):
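    """
    Build a per-parameter optimizer from the config.

    Each parameter gets its own group so that normalization layers
    (SOLVER.WEIGHT_DECAY_NORM), biases (SOLVER.BIAS_LR_FACTOR,
    SOLVER.WEIGHT_DECAY_BIAS) and the listed prior parameters can use a
    learning rate / weight decay different from the base values. The
    optimizer class is chosen by SOLVER.TYPE, and gradient clipping is
    attached afterwards if the config enables it.
    """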
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
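    # Visit every parameter of every module so each one can be given its own
    # lr / weight decay; `memo` skips parameters shared between modules.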
    for module in model.modules():
        for key, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            lr = cfg.SOLVER.BASE_LR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY
            if isinstance(module, norm_module_types) and cfg.SOLVER.WEIGHT_DECAY_NORM is not None:
                weight_decay = cfg.SOLVER.WEIGHT_DECAY_NORM
            elif key == "bias":
                if cfg.SOLVER.BIAS_LR_FACTOR is not None:
                    lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
                if cfg.SOLVER.WEIGHT_DECAY_BIAS is not None:
                    weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS

            # these params do not need weight decay at all
            # TODO parameterize these in configs instead.
            if key in ['priors_dims_per_cat', 'priors_z_scales', 'priors_z_stats']:
                weight_decay = 0.0

            params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
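    # Instantiate the optimizer named by SOLVER.TYPE. The base lr passed below
    # only acts as a default; every group above already carries its own lr.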
    if cfg.SOLVER.TYPE == 'sgd':
        optimizer = torch.optim.SGD(
            params,
            cfg.SOLVER.BASE_LR,
            momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    elif cfg.SOLVER.TYPE == 'adam':
        optimizer = torch.optim.Adam(params, cfg.SOLVER.BASE_LR, eps=1e-02)
    elif cfg.SOLVER.TYPE == 'adam+amsgrad':
        optimizer = torch.optim.Adam(params, cfg.SOLVER.BASE_LR, amsgrad=True, eps=1e-02)
    elif cfg.SOLVER.TYPE == 'adamw':
        optimizer = torch.optim.AdamW(params, cfg.SOLVER.BASE_LR, eps=1e-02)
    elif cfg.SOLVER.TYPE == 'adamw+amsgrad':
        optimizer = torch.optim.AdamW(params, cfg.SOLVER.BASE_LR, amsgrad=True, eps=1e-02)
    else:
        raise ValueError('{} is not supported as an optimizer.'.format(cfg.SOLVER.TYPE))

    optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer

def freeze_bn(network):
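    """
    Put every BatchNorm2d layer of `network` into eval mode and disable
    updates to its running statistics, so the stored mean/var stay fixed.
    Note that if the model is later switched back to train mode, these layers
    will normalize with batch statistics rather than the frozen ones.
    """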
    for _, module in network.named_modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            module.eval()
            module.track_running_stats = False
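

# A minimal usage sketch, not part of the original module: it assumes detectron2's
# default config plus the project-specific SOLVER.TYPE key consumed above, and a
# toy model purely for illustration.
if __name__ == "__main__":
    from detectron2.config import get_cfg

    cfg = get_cfg()
    # Custom key read by build_optimizer; assumed to be defined by the project's configs.
    cfg.SOLVER.TYPE = 'adamw'

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
    freeze_bn(model)
    optimizer = build_optimizer(cfg, model)
    # One param group per parameter: conv weight/bias + BN weight/bias -> 4 groups.
    print(type(optimizer).__name__, len(optimizer.param_groups))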