| """ Adaptive Gradient Clipping | |
| An impl of AGC, as per (https://arxiv.org/abs/2102.06171): | |
| @article{brock2021high, | |
| author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan}, | |
| title={High-Performance Large-Scale Image Recognition Without Normalization}, | |
| journal={arXiv preprint arXiv:}, | |
| year={2021} | |
| } | |
| Code references: | |
| * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets | |
| * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c | |
| Hacked together by / Copyright 2021 Ross Wightman | |
| """ | |
import torch


def unitwise_norm(x, norm_type=2.0):
    if x.ndim <= 1:
        return x.norm(norm_type)
    else:
        # works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
        # might need special cases for other weights (possibly MHA) where this may not be true
        return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
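

# Illustration (not part of the original file): for an nn.Conv2d weight of shape
# (out_ch, in_ch, kh, kw), unitwise_norm reduces over every dim except dim 0,
# giving one norm per output unit with shape (out_ch, 1, 1, 1), which broadcasts
# against the weight/grad tensors in adaptive_clip_grad below. For example:
#
#   w = torch.randn(64, 3, 3, 3)                               # conv kernel
#   n = unitwise_norm(w)                                       # shape (64, 1, 1, 1)
#   assert torch.allclose(n.squeeze(), w.flatten(1).norm(dim=1))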


def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    for p in parameters:
        if p.grad is None:
            continue
        p_data = p.detach()
        g_data = p.grad.detach()
        # per-unit clipping threshold: clip_factor * max(||W_i||, eps)
        max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
        grad_norm = unitwise_norm(g_data, norm_type=norm_type)
        # rescale only those gradients whose unit-wise norm exceeds the threshold
        clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
        new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
        p.grad.detach().copy_(new_grads)
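

if __name__ == "__main__":
    # Usage sketch (an assumption, not part of the original file): AGC replaces
    # global-norm clipping between backward() and step(). The paper applies it
    # when training NFNets and skips the final classifier layer; the toy model,
    # data, and optimizer here are hypothetical placeholders.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3),
        torch.nn.Flatten(),
        torch.nn.Linear(8 * 30 * 30, 10),
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    inputs = torch.randn(4, 3, 32, 32)
    targets = torch.randint(0, 10, (4,))

    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    adaptive_clip_grad(model.parameters(), clip_factor=0.01)  # clip before the step
    optimizer.step()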