Spaces:
Runtime error
Runtime error
| import torch | |
| from timm.utils.agc import adaptive_clip_grad | |
def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
    """Dispatch to a gradient clipping method.

    Args:
        parameters (Iterable): model parameters to clip
        value (float): clipping value/factor/norm, mode dependent
        mode (str): clipping mode, one of 'norm', 'value', 'agc'
        norm_type (float): p-norm, default 2.0 (not used by 'value' mode)

    Raises:
        ValueError: if ``mode`` is not one of 'norm', 'value', 'agc'.
    """
    if mode == 'norm':
        torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
    elif mode == 'value':
        torch.nn.utils.clip_grad_value_(parameters, value)
    elif mode == 'agc':
        adaptive_clip_grad(parameters, value, norm_type=norm_type)
    else:
        # Raise explicitly: an `assert` is stripped under `python -O`, which
        # would make an unknown mode silently skip clipping altogether.
        raise ValueError(f"Unknown clip mode ({mode}).")