Miroslav Purkrabek
add code
a249588
raw
history blame
9.37 kB
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Tuple, Union
import torch
from mmengine.dist import get_world_size
from mmengine.logging import print_log
from mmengine.model import BaseModel
from torch import Tensor
from mmpose.datasets.datasets.utils import parse_pose_metainfo
from mmpose.models.utils import check_and_update_config
from mmpose.registry import MODELS
from mmpose.utils.typing import (ConfigType, ForwardResults, OptConfigType,
Optional, OptMultiConfig, OptSampleList,
SampleList)
class BasePoseEstimator(BaseModel, metaclass=ABCMeta):
    """Base class for pose estimators.

    Args:
        backbone (dict | ConfigDict): The backbone config, built through
            the ``MODELS`` registry.
        neck (dict | ConfigDict, optional): The neck config. The ``neck``
            attribute is only created when this is not ``None``. Defaults
            to ``None``
        head (dict | ConfigDict, optional): The head config. The ``head``
            attribute is only created when this is not ``None``. Defaults
            to ``None``
        train_cfg (dict | ConfigDict, optional): The training config. A
            falsy value is stored as an empty dict. Defaults to ``None``
        test_cfg (dict | ConfigDict, optional): The testing config. A falsy
            value is stored as an empty dict; a copy is attached to the
            head as ``head.test_cfg``. Defaults to ``None``
        data_preprocessor (dict | ConfigDict, optional): The pre-processing
            config of :class:`BaseDataPreprocessor`. Defaults to ``None``
        use_syncbn (bool): whether to use SyncBatchNorm. Only takes effect
            when the world size is larger than 1. Defaults to False.
        init_cfg (dict | ConfigDict): The model initialization config.
            Defaults to ``None``
        metainfo (dict): Meta information for dataset, such as keypoints
            definition and properties. If set, the metainfo of the input data
            batch will be overridden. For more details, please refer to
            https://mmpose.readthedocs.io/en/latest/user_guides/
            prepare_datasets.html#create-a-custom-dataset-info-
            config-file-for-the-dataset. Defaults to ``None``
    """

    # Checkpoint format version. ``_load_state_dict_pre_hook`` skips the
    # legacy key conversion for checkpoints whose version is at least this.
    _version = 2

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 use_syncbn: bool = False,
                 init_cfg: OptMultiConfig = None,
                 metainfo: Optional[dict] = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.metainfo = self._load_metainfo(metainfo)
        self.train_cfg = train_cfg if train_cfg else {}
        self.test_cfg = test_cfg if test_cfg else {}

        self.backbone = MODELS.build(backbone)

        # the PR #2108 and #2126 modified the interface of neck and head.
        # The following function automatically detects outdated
        # configurations and updates them accordingly, while also providing
        # clear and concise information on the changes made.
        neck, head = check_and_update_config(neck, head)

        if neck is not None:
            self.neck = MODELS.build(neck)

        if head is not None:
            self.head = MODELS.build(head)
            # the head gets its own copy of the test config
            self.head.test_cfg = self.test_cfg.copy()

        # Register the hook to automatically convert old version state dicts
        self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)

        # TODO: Waiting for mmengine support
        if use_syncbn and get_world_size() > 1:
            torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)
            print_log('Using SyncBatchNorm()', 'current')

    def switch_to_deploy(self):
        """Switch the sub-modules to deploy mode.

        Every sub-module (excluding the estimator itself) that exposes a
        callable ``switch_to_deploy`` attribute is called with
        ``self.test_cfg``.
        """
        for name, layer in self.named_modules():
            # ``named_modules`` also yields the estimator itself; skip it to
            # avoid calling this method recursively.
            if layer is self:
                continue
            if callable(getattr(layer, 'switch_to_deploy', None)):
                print_log(f'module {name} has been switched to deploy mode',
                          'current')
                layer.switch_to_deploy(self.test_cfg)

    @property
    def with_neck(self) -> bool:
        """bool: whether the pose estimator has a neck."""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_head(self) -> bool:
        """bool: whether the pose estimator has a head."""
        return hasattr(self, 'head') and self.head is not None

    @staticmethod
    def _load_metainfo(metainfo: Optional[dict] = None) -> Optional[dict]:
        """Collect meta information from the dictionary of meta.

        Args:
            metainfo (dict, optional): Raw data of pose meta information.
                Defaults to ``None``

        Returns:
            dict | None: Parsed meta information, or ``None`` if no
            ``metainfo`` is given.

        Raises:
            TypeError: If ``metainfo`` is neither ``None`` nor a dict.
        """
        if metainfo is None:
            return None

        if not isinstance(metainfo, dict):
            raise TypeError(
                f'metainfo should be a dict, but got {type(metainfo)}')

        return parse_pose_metainfo(metainfo)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: OptSampleList = None,
                mode: str = 'tensor') -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: 'tensor', 'predict' and 'loss':

        - 'tensor': Forward the whole network and return tensor or tuple of
        tensor without any post-processing, same as a common nn.Module.
        - 'predict': Forward and return the predictions, which are fully
        processed to a list of :obj:`PoseDataSample`.
        - 'loss': Forward and return a dict of losses according to the given
        inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general. A list of tensors is stacked into
                a single tensor first.
            data_samples (list[:obj:`PoseDataSample`], optional): The
                annotation of every sample. Defaults to ``None``
            mode (str): Set the forward mode and return value type. Defaults
                to ``'tensor'``

        Returns:
            The return type depends on ``mode``.

            - If ``mode='tensor'``, return a tensor or a tuple of tensors
            - If ``mode='predict'``, return a list of :obj:`PoseDataSample`
                that contains the pose predictions
            - If ``mode='loss'``, return a dict of tensor(s) which is the loss
                function value

        Raises:
            RuntimeError: If ``mode`` is not one of the supported modes.
        """
        if isinstance(inputs, list):
            inputs = torch.stack(inputs)

        if mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            # use custom metainfo to override the default metainfo; nothing
            # to override when no data samples are given
            if self.metainfo is not None and data_samples is not None:
                for data_sample in data_samples:
                    data_sample.set_metainfo(self.metainfo)
            return self.predict(inputs, data_samples)
        elif mode == 'tensor':
            return self._forward(inputs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode.')

    @abstractmethod
    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples."""

    @abstractmethod
    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing."""

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None
                 ) -> Union[Tensor, Tuple[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (list[:obj:`PoseDataSample`], optional): Not used
                by this base implementation. Defaults to ``None``

        Returns:
            Union[Tensor | Tuple[Tensor]]: forward output of the network.
        """
        x = self.extract_feat(inputs)
        if self.with_head:
            x = self.head.forward(x)

        return x

    def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        Args:
            inputs (Tensor): Image tensor with shape (N, C, H ,W).

        Returns:
            tuple[Tensor]: Multi-level features that may have various
            resolutions.
        """
        x = self.backbone(inputs)
        if self.with_neck:
            x = self.neck(x)

        return x

    def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
                                  **kwargs):
        """A hook function that pre-processes ``state_dict`` in place before
        it is loaded. It will:

        1) convert old-version state dict of
            :class:`TopdownHeatmapSimpleHead` (before MMPose v1.0.0) to a
            compatible format of :class:`HeatmapHead`.
        2) remove the weights in data_preprocessor to avoid warning
            `unexpected key in source state_dict: ...`. These weights are
            initialized with given arguments and remain same during training
            and inference.

        The hook will be automatically registered during initialization.

        Args:
            state_dict (dict): The state dict being loaded; modified in
                place.
            prefix (str): The key prefix of this module inside
                ``state_dict`` (empty when this module is the root).
            local_meta (dict): Metadata of this module in the checkpoint;
                only its ``'version'`` entry is read.
        """
        # remove the keys in data_preprocessor to avoid warning. The keys are
        # looked up relative to ``prefix`` so that the removal also works
        # when this estimator is a sub-module of the model being loaded.
        for name in ('data_preprocessor.mean', 'data_preprocessor.std'):
            state_dict.pop(prefix + name, None)

        version = local_meta.get('version', None)
        if version and version >= self._version:
            return

        # convert old-version state dict; iterate over a snapshot of the keys
        # because the dict is mutated inside the loop
        for k in list(state_dict.keys()):
            if 'keypoint_head' in k:
                v = state_dict.pop(k)
                state_dict[k.replace('keypoint_head', 'head')] = v