Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import math | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| from mmengine.model import ExponentialMovingAverage | |
| from torch import Tensor | |
| from mmpose.registry import MODELS | |
@MODELS.register_module()
class ExpMomentumEMA(ExponentialMovingAverage):
    """Exponential moving average (EMA) with exponential momentum strategy,
    which is used in YOLOX.

    Ported from `the implementation of MMDetection
    <https://github.com/open-mmlab/mmdetection/blob/3.x/mmdet/models/layers/ema.py>`_.

    The class is registered in the ``MODELS`` registry so it can be built
    from a config by its type name.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The base momentum that the annealed momentum decays
            towards as training proceeds. At every update the EMA parameters
            are blended as
            `averaged_param = (1 - m) * averaged_param + m * source_param`,
            where `m` is the annealed momentum described under ``gamma``.
            Defaults to 0.0002.
        gamma (int): Controls how fast the momentum anneals. A larger
            momentum is used early in training and it gradually decays
            towards ``momentum`` so the EMA model is updated smoothly. The
            annealed momentum is
            `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`.
            Must be greater than 0. Defaults to 2000.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): If True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.

    Raises:
        AssertionError: If ``gamma`` is not greater than 0.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 gamma: int = 2000,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        # gamma is used as a divisor in the annealing exponent, so it must be
        # strictly positive.
        assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
        self.gamma = gamma

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Compute the moving average of the parameters using the exponential
        momentum strategy, updating ``averaged_param`` in place.

        The annealed momentum decays from
        `(1 - self.momentum) * exp(-1 / self.gamma) + self.momentum` at
        ``steps == 0`` towards ``self.momentum`` as ``steps`` grows. For a
        large ``gamma`` (e.g. the default 2000), it starts close to 1, so the
        averaged parameters closely track the source parameters early on.

        Args:
            averaged_param (Tensor): The averaged parameters. Modified in
                place.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        # float() keeps the exponent a Python float; it also accepts a
        # 0-dim tensor for ``steps``.
        momentum = (1 - self.momentum) * math.exp(
            -float(1 + steps) / self.gamma) + self.momentum
        # In-place blend: avg <- (1 - momentum) * avg + momentum * src.
        averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)