| """ PyTorch MADGRAD optimizer | |
| MADGRAD: https://arxiv.org/abs/2101.11075 | |
| Code from: https://github.com/facebookresearch/madgrad | |
| """ | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import math | |
| from typing import TYPE_CHECKING, Any, Callable, Optional | |
| import torch | |
| import torch.optim | |
| if TYPE_CHECKING: | |
| from torch.optim.optimizer import _params_t | |
| else: | |
| _params_t = Any | |


class MADGRAD(torch.optim.Optimizer):
    """
    MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
    Optimization.

    .. _MADGRAD: https://arxiv.org/abs/2101.11075

    MADGRAD is a general purpose optimizer that can be used in place of SGD or Adam,
    and may converge faster and generalize better. Currently GPU-only.
    Typically, the same learning rate schedule that is used for SGD or Adam may
    be used. The overall learning rate is not comparable to either method and
    should be determined by a hyper-parameter sweep.

    MADGRAD requires less weight decay than other methods, often as little as
    zero. Momentum values used for SGD or Adam's beta1 should work here also.

    On sparse problems both weight_decay and momentum should be set to 0.
    (A minimal usage sketch is included at the end of this file.)

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate (default: 1e-2).
        momentum (float):
            Momentum value in the range [0,1) (default: 0.9).
        weight_decay (float):
            Weight decay, i.e., an L2 penalty (default: 0).
        eps (float):
            Term added to the denominator outside of the root operation to
            improve numerical stability (default: 1e-6).
        decoupled_decay (bool):
            Apply decoupled (AdamW-style) weight decay directly to the parameters
            instead of adding an L2 penalty to the gradient (default: False).
    """

    def __init__(
            self,
            params: _params_t,
            lr: float = 1e-2,
            momentum: float = 0.9,
            weight_decay: float = 0,
            eps: float = 1e-6,
            decoupled_decay: bool = False,
    ):
        if momentum < 0 or momentum >= 1:
            raise ValueError(f"Momentum {momentum} must be in the range [0,1)")
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if eps < 0:
            raise ValueError(f"Eps {eps} must be non-negative")

        defaults = dict(
            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
        super().__init__(params, defaults)

    # Capability flags queried by memory-efficient fp16 optimizer wrappers
    # (assumption: this mirrors the fairseq-style optimizer interface).
    def supports_memory_efficient_fp16(self) -> bool:
        return False

    def supports_flat_params(self) -> bool:
        return True

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
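
        # The loop below implements the MADGRAD update; summarized here as a
        # reading aid, in the code's own notation (dense, momentum != 0 case):
        #   lamb_k = (lr + eps) * sqrt(k)                  with k = state['step']
        #   nu_k   = nu_{k-1} + lamb_k * g_k^2             accumulated in 'grad_sum_sq'
        #   s_k    = s_{k-1} + lamb_k * g_k                accumulated in 's'
        #   z_k    = x_0 - s_k / (nu_k^(1/3) + eps)        dual-averaging iterate
        #   x_k    = (1 - ck) * x_{k-1} + ck * z_k         with ck = 1 - momentum
        # With momentum == 0 this reduces to x_k = z_k, and the sparse branch applies
        # the same update restricted to coordinates with non-zero gradients.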

        for group in self.param_groups:
            eps = group['eps']
            lr = group['lr'] + eps
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            ck = 1 - momentum

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if momentum != 0.0 and grad.is_sparse:
                    raise RuntimeError("momentum != 0 is not compatible with sparse gradients")

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['grad_sum_sq'] = torch.zeros_like(p)
                    state['s'] = torch.zeros_like(p)
                    if momentum != 0:
                        state['x0'] = torch.clone(p).detach()

                state['step'] += 1
                grad_sum_sq = state['grad_sum_sq']
                s = state['s']
                lamb = lr * math.sqrt(state['step'])

                # Apply weight decay
                if weight_decay != 0:
                    if group['decoupled_decay']:
                        p.mul_(1.0 - group['lr'] * weight_decay)
                    else:
                        if grad.is_sparse:
                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                        grad.add_(p, alpha=weight_decay)

                if grad.is_sparse:
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Compute x_0 from other known quantities
                    rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)

                    # Dense + sparse op
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=lamb)
                    grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                    rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                    s.add_(grad, alpha=lamb)
                    s_masked._values().add_(grad_val, alpha=lamb)

                    # update masked copy of p
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
                    # Copy updated masked p to dense p using an add operation
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                    p.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
                        # Compute x_0 from other known quantities
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.addcdiv(s, rms, value=1)
                    else:
                        x0 = state['x0']

                    # Accumulate second moments
                    grad_sum_sq.addcmul_(grad, grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s
                    s.add_(grad, alpha=lamb)

                    # Step
                    if momentum == 0:
                        p.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)
                        # p is a moving average of z
                        p.mul_(1 - ck).add_(z, alpha=ck)

        return loss
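

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original optimizer code). It exercises
# MADGRAD through the standard torch.optim loop on a toy least-squares problem;
# the model, data, and hyper-parameters below are illustrative assumptions only.
# The class docstring notes the optimizer is currently GPU-only, so the demo
# moves to CUDA when available and otherwise just illustrates the API on CPU.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Toy data: fit y = 3x + 1 with a single linear layer.
    x = torch.randn(256, 1, device=device)
    y = 3.0 * x + 1.0
    model = torch.nn.Linear(1, 1).to(device)

    # lr is not directly comparable to SGD/Adam and would normally be swept;
    # weight_decay can often stay at zero (see the class docstring).
    optimizer = MADGRAD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=0.0)

    for _ in range(200):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()

    print(f"final loss: {loss.item():.6f}")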