Spaces:
Runtime error
Runtime error
| # This script is extended based on https://github.com/nkolot/SPIN/blob/master/models/smpl.py | |
| import json | |
| import os | |
| import pickle | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from lib.pymafx.core import constants, path_config | |
| from lib.smplx import SMPL as _SMPL | |
| from lib.smplx import FLAMELayer, MANOLayer, SMPLXLayer | |
| from lib.smplx.body_models import SMPLXOutput | |
| from lib.smplx.lbs import ( | |
| batch_rodrigues, | |
| blend_shapes, | |
| transform_mat, | |
| vertices2joints, | |
| ) | |
# Module-level aliases for values defined in ``path_config``.
# SMPL_MODEL_DIR is joined with relative file names via ``os.path.join`` below, so it is a
# directory path; SMPL_MEAN_PARAMS is not used in this file (presumably a file path - TODO confirm).
SMPL_MEAN_PARAMS = path_config.SMPL_MEAN_PARAMS
SMPL_MODEL_DIR = path_config.SMPL_MODEL_DIR
@dataclass
class ModelOutput(SMPLXOutput):
    """Output container shared by the SMPL / SMPL-X / MANO / FLAME wrappers in this file.

    Extends ``SMPLXOutput`` with optional fields; each wrapper fills in only the fields
    it computes and leaves the rest as ``None``.

    The ``@dataclass`` decorator is required: the wrappers construct this class with
    keyword arguments such as ``smpl_joints=...``, which only works when the fields
    declared here are part of the generated ``__init__``.
    """
    # First 24 entries of the SMPL joint set.
    smpl_joints: Optional[torch.Tensor] = None
    # Subset of the last 24 mapped joints selected by ``constants.J24_TO_J19``.
    joints_J19: Optional[torch.Tensor] = None
    # Vertices produced directly by the SMPL-X model.
    smplx_vertices: Optional[torch.Tensor] = None
    # Vertices produced by the FLAME model.
    flame_vertices: Optional[torch.Tensor] = None
    # Hand vertices (selected from the SMPL mesh, or the MANO mesh for ``rhand_vertices``).
    lhand_vertices: Optional[torch.Tensor] = None
    rhand_vertices: Optional[torch.Tensor] = None
    # Hand joints.
    lhand_joints: Optional[torch.Tensor] = None
    rhand_joints: Optional[torch.Tensor] = None
    # Facial landmark joints.
    face_joints: Optional[torch.Tensor] = None
    # Foot joints.
    lfoot_joints: Optional[torch.Tensor] = None
    rfoot_joints: Optional[torch.Tensor] = None
class SMPL(_SMPL):
    """SMPL wrapper that adds an extended joint set on top of the official model."""

    def __init__(
        self,
        create_betas=False,
        create_global_orient=False,
        create_body_pose=False,
        create_transl=False,
        *args,
        **kwargs
    ):
        """Build the base model and register the extra buffers used by this wrapper.

        The ``create_*`` flags are forwarded to the parent constructor (they default to
        ``False`` here); every other argument is passed through unchanged.
        """
        super().__init__(
            *args,
            create_betas=create_betas,
            create_global_orient=create_global_orient,
            create_body_pose=create_body_pose,
            create_transl=create_transl,
            **kwargs
        )
        joint_indices = [constants.JOINT_MAP[name] for name in constants.JOINT_NAMES]
        extra_regressor = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra', torch.tensor(extra_regressor, dtype=torch.float32)
        )
        self.joint_map = torch.tensor(joint_indices, dtype=torch.long)

        # Joints regressed from the template mesh, cached for get_global_rotation().
        self.register_buffer(
            'tpose_joints',
            vertices2joints(self.J_regressor, self.v_template.unsqueeze(0)),
        )

    def forward(self, *args, **kwargs):
        """Run the base SMPL model and return a ``ModelOutput`` with the remapped joints.

        ``get_skin`` is always forced to ``True``. The base joints are concatenated with
        the joints regressed by ``J_regressor_extra`` and then reordered by
        ``joint_map``.
        """
        kwargs['get_skin'] = True
        base_output = super().forward(*args, **kwargs)
        verts = base_output.vertices

        # Base joints followed by the additionally regressed ones, then remapped.
        regressed = vertices2joints(self.J_regressor_extra, verts)
        mapped_joints = torch.cat([base_output.joints, regressed], dim=1)[:, self.joint_map, :]
        last_24 = mapped_joints[:, -24:, :]

        return ModelOutput(
            vertices=verts,
            global_orient=base_output.global_orient,
            body_pose=base_output.body_pose,
            joints=mapped_joints,
            joints_J19=last_24[:, constants.J24_TO_J19, :],
            smpl_joints=base_output.joints[:, :24],
            betas=base_output.betas,
            full_pose=base_output.full_pose
        )

    def get_global_rotation(
        self,
        global_orient: Optional[torch.Tensor] = None,
        body_pose: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Chain the per-joint rotations along the kinematic tree of the template.

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Root rotation in rotation-matrix format. Identity when omitted.
        body_pose: torch.tensor, optional, shape BxJx3x3
            Body joint rotations in rotation-matrix format. Identity when omitted.
        **kwargs
            Accepted for call compatibility; not used.

        Returns
        -------
        global_rotmat: the 3x3 rotation part of each joint's chained transform
        posed_joints: the translation part of each joint's chained transform
        """
        device = self.shapedirs.device
        dtype = self.shapedirs.dtype

        # The batch size is the largest leading dimension among the supplied poses.
        supplied = [pose for pose in (global_orient, body_pose) if pose is not None]
        batch_size = max([1] + [len(pose) for pose in supplied])

        def identity_rotations(num_joints):
            eye = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
            return eye.expand(batch_size, num_joints, -1, -1).contiguous()

        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)

        # Root rotation followed by the body joint rotations.
        rot_mats = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
            ],
            dim=1
        ).view(batch_size, -1, 3, 3)

        # Template joints (cached in __init__), as column vectors.
        rest_joints = self.tpose_joints.expand(batch_size, -1, -1).unsqueeze(-1)
        num_joints = rest_joints.shape[1]

        # Express every non-root joint relative to its parent.
        offsets = rest_joints.clone()
        offsets[:, 1:] -= rest_joints[:, self.parents[1:]]

        local_transforms = transform_mat(
            rot_mats.reshape(-1, 3, 3), offsets.reshape(-1, 3, 1)
        ).reshape(-1, num_joints, 4, 4)

        # Accumulate parent-to-child transforms; the root has no parent.
        chain = [local_transforms[:, 0]]
        for child in range(1, self.parents.shape[0]):
            parent_transform = chain[self.parents[child]]
            chain.append(torch.matmul(parent_transform, local_transforms[:, child]))
        world_transforms = torch.stack(chain, dim=1)

        global_rotmat = world_transforms[:, :, :3, :3]
        # The last column of each transform holds the posed joint location.
        posed_joints = world_transforms[:, :, :3, 3]
        return global_rotmat, posed_joints
class SMPLX(SMPLXLayer):
    """ Extension of the official SMPLX implementation to support more functions """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_global_rotation(
        self,
        global_orient: Optional[torch.Tensor] = None,
        body_pose: Optional[torch.Tensor] = None,
        left_hand_pose: Optional[torch.Tensor] = None,
        right_hand_pose: Optional[torch.Tensor] = None,
        jaw_pose: Optional[torch.Tensor] = None,
        leye_pose: Optional[torch.Tensor] = None,
        reye_pose: Optional[torch.Tensor] = None,
        **kwargs
    ):
        '''
        Chain the per-joint rotations along the kinematic tree of the template mesh.

        No vertices are computed and no shape or expression parameters are used: the
        joints are regressed from ``v_template`` only.

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Root rotation in rotation matrix format. (default=None -> identity)
        body_pose: torch.tensor, optional, shape BxJx3x3
            Body joint rotations in rotation matrix format.
            (default=None -> identity)
        left_hand_pose: torch.tensor, optional, shape Bx15x3x3
            Left hand joint rotations in rotation matrix format.
            (default=None -> identity)
        right_hand_pose: torch.tensor, optional, shape Bx15x3x3
            Right hand joint rotations in rotation matrix format.
            (default=None -> identity)
        jaw_pose: torch.tensor, optional, shape Bx3x3
            Jaw rotation in rotation matrix format. (default=None -> identity)
        leye_pose: torch.tensor, optional, shape Bx3x3
            Left eye rotation in rotation matrix format. (default=None -> identity)
        reye_pose: torch.tensor, optional, shape Bx3x3
            Right eye rotation in rotation matrix format. (default=None -> identity)
        **kwargs
            Accepted for call compatibility; not used.

        Returns
        -------
            global_rotmat: the 3x3 rotation part of each joint's chained transform
            posed_joints: the translation part of each joint's chained transform
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Every supplied pose takes part in batch-size inference, including the eye
        # poses, so a call that only passes batched eye poses is handled correctly.
        model_vars = [
            global_orient, body_pose, left_hand_pose, right_hand_pose, jaw_pose, leye_pose,
            reye_pose
        ]
        batch_size = max([1] + [len(var) for var in model_vars if var is not None])

        def identity_rotations(num_joints):
            eye = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
            return eye.expand(batch_size, num_joints, -1, -1).contiguous()

        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)
        # Use the same joint count the reshape below relies on, not a literal.
        if left_hand_pose is None:
            left_hand_pose = identity_rotations(self.NUM_HAND_JOINTS)
        if right_hand_pose is None:
            right_hand_pose = identity_rotations(self.NUM_HAND_JOINTS)
        if jaw_pose is None:
            jaw_pose = identity_rotations(1)
        if leye_pose is None:
            leye_pose = identity_rotations(1)
        if reye_pose is None:
            reye_pose = identity_rotations(1)

        # Concatenate all pose matrices in the order:
        # root, body, jaw, left eye, right eye, left hand, right hand.
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
                jaw_pose.reshape(-1, 1, 3, 3),
                leye_pose.reshape(-1, 1, 3, 3),
                reye_pose.reshape(-1, 1, 3, 3),
                left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
                right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)
            ],
            dim=1
        )
        rot_mats = full_pose.view(batch_size, -1, 3, 3)

        # Joints regressed from the template mesh, as column vectors.
        joints = vertices2joints(
            self.J_regressor,
            self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
        )
        joints = torch.unsqueeze(joints, dim=-1)

        # Express every non-root joint relative to its parent.
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, self.parents[1:]]

        transforms_mat = transform_mat(
            rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)
        ).reshape(-1, joints.shape[1], 4, 4)

        # Accumulate parent-to-child transforms; the root has no parent.
        transform_chain = [transforms_mat[:, 0]]
        for i in range(1, self.parents.shape[0]):
            curr_res = torch.matmul(transform_chain[self.parents[i]], transforms_mat[:, i])
            transform_chain.append(curr_res)
        transforms = torch.stack(transform_chain, dim=1)

        global_rotmat = transforms[:, :, :3, :3]
        # The last column of the transformations contains the posed joints
        posed_joints = transforms[:, :, :3, 3]
        return global_rotmat, posed_joints
class SMPLX_ALL(nn.Module):
    """Gender-aware SMPL-X wrapper that also exposes SMPL-topology vertices and joints.

    One ``SMPLX`` model is held per gender. ``forward`` routes every sample of the batch
    to the model matching its gender code, then restores the original batch order.
    """

    # The position in this tuple is the integer gender code expected in ``gender``.
    _GENDER_ORDER = ('male', 'female', 'neutral')

    def __init__(self, batch_size=1, use_face_contour=True, all_gender=False, **kwargs):
        """Load the SMPL-X model(s) and the auxiliary mapping data.

        Parameters
        ----------
        batch_size: int
            Forwarded to every ``SMPLX`` constructor.
        use_face_contour: bool
            Forwarded to ``SMPLX``; also selects how many trailing joints ``forward``
            reports as ``face_joints`` (68 when True, 51 otherwise).
        all_gender: bool
            Load male, female and neutral models when True, only neutral otherwise.
        **kwargs
            Forwarded to every ``SMPLX`` constructor.
        """
        super().__init__()
        numBetas = 10
        self.use_face_contour = use_face_contour
        self.genders = ['male', 'female', 'neutral'] if all_gender else ['neutral']

        self.model_dict = nn.ModuleDict({
            gender: SMPLX(
                path_config.SMPL_MODEL_DIR,
                gender=gender,
                ext='npz',
                num_betas=numBetas,
                use_pca=False,
                batch_size=batch_size,
                use_face_contour=use_face_contour,
                num_pca_comps=45,
                **kwargs
            )
            for gender in self.genders
        })
        self.model_neutral = self.model_dict['neutral']

        joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
        J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)
        )
        self.joint_map = torch.tensor(joints, dtype=torch.long)

        # smplx_to_smpl.pkl, file source: https://smpl-x.is.tue.mpg.de
        # NOTE: pickle can execute arbitrary code on load; only use trusted model files.
        # The context manager closes the handle, which the bare open() call did not.
        with open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb') as f:
            smplx_to_smpl = pickle.load(f)
        self.register_buffer(
            'smplx2smpl', torch.tensor(smplx_to_smpl['matrix'][None], dtype=torch.float32)
        )

        smpl2limb_vert_faces = get_partial_smpl('smpl')
        self.smpl2lhand = torch.from_numpy(smpl2limb_vert_faces['lhand']['vids']).long()
        self.smpl2rhand = torch.from_numpy(smpl2limb_vert_faces['rhand']['vids']).long()

        # left and right hand joint mapping
        smplx2lhand_joints = [
            constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.HAND_NAMES
        ]
        smplx2rhand_joints = [
            constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.HAND_NAMES
        ]
        self.smplx2lh_joint_map = torch.tensor(smplx2lhand_joints, dtype=torch.long)
        self.smplx2rh_joint_map = torch.tensor(smplx2rhand_joints, dtype=torch.long)

        # left and right foot joint mapping
        smplx2lfoot_joints = [
            constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.FOOT_NAMES
        ]
        smplx2rfoot_joints = [
            constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.FOOT_NAMES
        ]
        self.smplx2lf_joint_map = torch.tensor(smplx2lfoot_joints, dtype=torch.long)
        self.smplx2rf_joint_map = torch.tensor(smplx2rfoot_joints, dtype=torch.long)

        # Per-gender joint template and joint shape directions for the first 24 joints,
        # so get_tpose() can obtain joints from betas without running the full model.
        for g in self.genders:
            J_template = torch.einsum(
                'ji,ik->jk', [self.model_dict[g].J_regressor[:24], self.model_dict[g].v_template]
            )
            J_dirs = torch.einsum(
                'ji,ikl->jkl', [self.model_dict[g].J_regressor[:24], self.model_dict[g].shapedirs]
            )
            self.register_buffer(f'{g}_J_template', J_template)
            self.register_buffer(f'{g}_J_dirs', J_dirs)

    @staticmethod
    def _restore_batch_order(gender_idx_list, device):
        """Return indices that undo the per-gender grouping of a batch.

        ``gender_idx_list`` holds the original sample index of every row of the
        gender-grouped concatenation; indexing that concatenation with the returned
        tensor puts the rows back in their original order.

        Raises ``ValueError`` when a sample was not assigned to any gender group,
        i.e. its gender code is not 0, 1 or 2.
        """
        position = {sample: slot for slot, sample in enumerate(gender_idx_list)}
        try:
            order = [position[sample] for sample in range(len(gender_idx_list))]
        except KeyError as err:
            raise ValueError(
                'gender codes must be 0 (male), 1 (female) or 2 (neutral); '
                'sample {} was not assigned to any model'.format(err.args[0])
            ) from err
        return torch.tensor(order).long().to(device)

    def forward(self, *args, **kwargs):
        """Run SMPL-X per gender and return a ``ModelOutput``.

        Keyword arguments read here:
            body_pose (required), betas, global_orient, left_hand_pose,
            right_hand_pose, jaw_pose, leye_pose, reye_pose,
            gender (defaults to code 2, i.e. neutral, for every sample),
            pose2rot (defaults to True: the poses are converted with
            ``batch_rodrigues`` before being passed to the model).

        A ``body_pose`` with 23 joints is truncated to its first 21 joints.
        """
        batch_size = kwargs['body_pose'].shape[0]
        device = kwargs['body_pose'].device
        kwargs['get_skin'] = True
        if 'pose2rot' not in kwargs:
            kwargs['pose2rot'] = True
        if 'gender' not in kwargs:
            kwargs['gender'] = 2 * torch.ones(batch_size).to(device)

        # pose for 55 joints: 1, 21, 15, 15, 1, 1, 1
        pose_keys = [
            'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'jaw_pose',
            'leye_pose', 'reye_pose'
        ]
        param_keys = ['betas'] + pose_keys
        if kwargs['pose2rot']:
            for key in pose_keys:
                if key in kwargs:
                    kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view([
                        batch_size, -1, 3, 3
                    ])
        if kwargs['body_pose'].shape[1] == 23:
            # remove hand pose in the body_pose
            kwargs['body_pose'] = kwargs['body_pose'][:, :21]

        # Run each gender's model on the samples that carry its gender code.
        gender_idx_list = []
        smplx_vertices, smplx_joints = [], []
        for gi, g in enumerate(self._GENDER_ORDER):
            gender_idx = ((kwargs['gender'] == gi).nonzero(as_tuple=True)[0])
            if len(gender_idx) == 0:
                continue
            gender_idx_list.extend([int(idx) for idx in gender_idx])
            gender_kwargs = {'get_skin': kwargs['get_skin'], 'pose2rot': kwargs['pose2rot']}
            gender_kwargs.update({k: kwargs[k][gender_idx] for k in param_keys if k in kwargs})
            gender_smplx_output = self.model_dict[g].forward(*args, **gender_kwargs)
            smplx_vertices.append(gender_smplx_output.vertices)
            smplx_joints.append(gender_smplx_output.joints)

        # Undo the per-gender grouping.
        idx_rearrange = self._restore_batch_order(gender_idx_list, device)
        smplx_vertices = torch.cat(smplx_vertices)[idx_rearrange]
        smplx_joints = torch.cat(smplx_joints)[idx_rearrange]

        # constants.HAND_NAMES
        lhand_joints = smplx_joints[:, self.smplx2lh_joint_map]
        rhand_joints = smplx_joints[:, self.smplx2rh_joint_map]
        # constants.FACIAL_LANDMARKS
        face_joints = smplx_joints[:, -68:] if self.use_face_contour else smplx_joints[:, -51:]
        # constants.FOOT_NAMES
        lfoot_joints = smplx_joints[:, self.smplx2lf_joint_map]
        rfoot_joints = smplx_joints[:, self.smplx2rf_joint_map]

        # Map the SMPL-X vertices onto the SMPL topology with the transfer matrix.
        smpl_vertices = torch.bmm(self.smplx2smpl.expand(batch_size, -1, -1), smplx_vertices)
        lhand_vertices = smpl_vertices[:, self.smpl2lhand]
        rhand_vertices = smpl_vertices[:, self.smpl2rhand]

        extra_joints = vertices2joints(self.J_regressor_extra, smpl_vertices)
        smplx_j45 = smplx_joints[:, constants.SMPLX2SMPL_J45]
        joints = torch.cat([smplx_j45, extra_joints], dim=1)
        smpl_joints = smplx_j45[:, :24]
        joints = joints[:, self.joint_map, :]
        joints_J24 = joints[:, -24:, :]
        joints_J19 = joints_J24[:, constants.J24_TO_J19, :]

        output = ModelOutput(
            vertices=smpl_vertices,
            smplx_vertices=smplx_vertices,
            lhand_vertices=lhand_vertices,
            rhand_vertices=rhand_vertices,
            joints=joints,
            joints_J19=joints_J19,
            smpl_joints=smpl_joints,
            lhand_joints=lhand_joints,
            rhand_joints=rhand_joints,
            lfoot_joints=lfoot_joints,
            rfoot_joints=rfoot_joints,
            face_joints=face_joints,
        )
        return output

    def get_tpose(self, betas=None, gender=None):
        """Return the first 24 rest-pose joints for the given shape parameters.

        Parameters
        ----------
        betas: torch.Tensor, optional
            Shape coefficients; a single all-zero row of 10 values when omitted.
        gender: torch.Tensor, optional
            Per-sample gender codes (0 male, 1 female, 2 neutral); neutral for every
            sample when omitted.

        Returns
        -------
        torch.Tensor with one set of joints per row of ``betas``, in the input order.
        """
        if betas is None:
            betas = torch.zeros(1, 10).to(self.J_regressor_extra.device)
        batch_size = betas.shape[0]
        device = betas.device
        if gender is None:
            gender = 2 * torch.ones(batch_size).to(device)

        gender_idx_list = []
        smplx_joints = []
        for gi, g in enumerate(self._GENDER_ORDER):
            gender_idx = ((gender == gi).nonzero(as_tuple=True)[0])
            if len(gender_idx) == 0:
                continue
            gender_idx_list.extend([int(idx) for idx in gender_idx])
            # Joint template plus the shape-dependent joint offset.
            J = getattr(self, f'{g}_J_template').unsqueeze(0) + blend_shapes(
                betas[gender_idx], getattr(self, f'{g}_J_dirs')
            )
            smplx_joints.append(J)

        idx_rearrange = self._restore_batch_order(gender_idx_list, device)
        smplx_joints = torch.cat(smplx_joints)[idx_rearrange]
        return smplx_joints
class MANO(MANOLayer):
    """MANO hand model wrapper that appends fingertip joints to the model output."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Run MANO and return the hand vertices plus 21 reordered joints.

        Reads ``global_orient`` and ``right_hand_pose`` from the keyword arguments.
        With ``pose2rot`` true (the default) both are converted with
        ``batch_rodrigues`` first. ``right_hand_pose`` is handed to the parent
        layer under the name ``hand_pose``.
        """
        kwargs.setdefault('pose2rot', True)
        batch_size = kwargs['global_orient'].shape[0]

        if kwargs['pose2rot']:
            for key in ('global_orient', 'right_hand_pose'):
                if key not in kwargs:
                    continue
                flat_pose = kwargs[key].contiguous().view(-1, 3)
                kwargs[key] = batch_rodrigues(flat_pose).view([batch_size, -1, 3, 3])
        kwargs['hand_pose'] = kwargs.pop('right_hand_pose')

        mano_output = super().forward(*args, **kwargs)
        hand_verts = mano_output.vertices

        # https://github.com/hassony2/manopth/blob/master/manopth/manolayer.py#L248-L260
        # Five mesh vertices, one per finger, are appended to serve as fingertip joints.
        fingertip_vertex_ids = [745, 317, 445, 556, 673]
        hand_joints = torch.cat([mano_output.joints, hand_verts[:, fingertip_vertex_ids]], dim=1)

        # Reorder joints to match visualization utilities
        joint_order = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
        hand_joints = hand_joints[:, joint_order]

        return ModelOutput(
            rhand_vertices=hand_verts,
            rhand_joints=hand_joints,
        )
class FLAME(FLAMELayer):
    """FLAME head model wrapper returning its results as a ``ModelOutput``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Run FLAME and return its vertices and face joints.

        ``global_orient`` is required; ``jaw_pose``, ``leye_pose`` and ``reye_pose``
        are optional. With ``pose2rot`` true (the default) every supplied pose is
        converted with ``batch_rodrigues`` first. The first five joints of the model
        output are dropped from ``face_joints``.
        """
        kwargs.setdefault('pose2rot', True)
        batch_size = kwargs['global_orient'].shape[0]

        if kwargs['pose2rot']:
            for key in ('global_orient', 'jaw_pose', 'leye_pose', 'reye_pose'):
                if key not in kwargs:
                    continue
                flat_pose = kwargs[key].contiguous().view(-1, 3)
                kwargs[key] = batch_rodrigues(flat_pose).view([batch_size, -1, 3, 3])

        flame_output = super().forward(*args, **kwargs)
        return ModelOutput(
            flame_vertices=flame_output.vertices,
            face_joints=flame_output.joints[:, 5:],
        )
class SMPL_Family():
    """Thin dispatcher that builds one of the body-model wrappers by name.

    Calling the instance forwards to the wrapped model; ``get_tpose`` forwards to the
    wrapped model's ``get_tpose`` method.
    """

    def __init__(self, model_type='smpl', *args, **kwargs):
        """Instantiate the model selected by ``model_type``.

        Parameters
        ----------
        model_type: str
            One of 'smpl', 'smplx', 'mano' or 'flame'.
        *args, **kwargs
            Forwarded to the selected model's constructor.

        Raises
        ------
        ValueError
            If ``model_type`` is not one of the supported names. Without this check
            the instance was left without a ``model`` attribute and only failed later
            with an ``AttributeError`` on first use.
        """
        if model_type == 'smpl':
            self.model = SMPL(model_path=SMPL_MODEL_DIR, *args, **kwargs)
        elif model_type == 'smplx':
            self.model = SMPLX_ALL(*args, **kwargs)
        elif model_type == 'mano':
            self.model = MANO(
                model_path=SMPL_MODEL_DIR, is_rhand=True, use_pca=False, *args, **kwargs
            )
        elif model_type == 'flame':
            self.model = FLAME(model_path=SMPL_MODEL_DIR, use_face_contour=True, *args, **kwargs)
        else:
            raise ValueError(
                "unsupported model_type {!r}; expected 'smpl', 'smplx', 'mano' or 'flame'".
                format(model_type)
            )

    def __call__(self, *args, **kwargs):
        """Forward the call to the wrapped model."""
        return self.model(*args, **kwargs)

    def get_tpose(self, *args, **kwargs):
        """Forward to the wrapped model's ``get_tpose``."""
        return self.model.get_tpose(*args, **kwargs)
| # def to(self, device): | |
| # self.model.to(device) | |
| # def cuda(self, device=None): | |
| # if device is None: | |
| # self.model.cuda() | |
| # else: | |
| # self.model.cuda(device) | |
def get_smpl_faces():
    """Return the ``faces`` attribute of a freshly built SMPL model."""
    return SMPL(model_path=SMPL_MODEL_DIR, batch_size=1).faces
def get_smplx_faces():
    """Return the ``faces`` attribute of a freshly built SMPL-X model."""
    return SMPLX(SMPL_MODEL_DIR, batch_size=1).faces
def get_mano_faces(hand_type='right'):
    """Return the ``faces`` attribute of the MANO model for the requested hand.

    ``hand_type`` must be 'right' or 'left'; anything else fails the assertion.
    """
    assert hand_type in ['right', 'left']
    hand_model = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=(hand_type == 'right'))
    return hand_model.faces
def get_flame_faces():
    """Return the ``faces`` attribute of a freshly built FLAME model."""
    return FLAME(SMPL_MODEL_DIR, batch_size=1).faces
def get_model_faces(type='smpl'):
    """Return the mesh faces of the model named by ``type``.

    Supported names are 'smpl', 'smplx', 'mano' and 'flame'; any other value
    yields ``None``.
    """
    if type == 'smpl':
        return get_smpl_faces()
    if type == 'smplx':
        return get_smplx_faces()
    if type == 'mano':
        return get_mano_faces()
    if type == 'flame':
        return get_flame_faces()
    return None
def get_model_tpose(type='smpl'):
    """Return the rest-pose vertices of the model named by ``type``.

    Supported names are 'smpl', 'smplx', 'mano' and 'flame'; any other value
    yields ``None``.
    """
    if type == 'smpl':
        return get_smpl_tpose()
    if type == 'smplx':
        return get_smplx_tpose()
    if type == 'mano':
        return get_mano_tpose()
    if type == 'flame':
        return get_flame_tpose()
    return None
def get_smpl_tpose():
    """Return the detached vertices of an SMPL model called without arguments."""
    model = SMPL(
        create_betas=True,
        create_global_orient=True,
        create_body_pose=True,
        model_path=SMPL_MODEL_DIR,
        batch_size=1
    )
    # The model is built with batch_size=1, so take the only sample.
    return model().vertices[0].detach()
def get_smpl_tpose_joint():
    """Return the detached ``smpl_joints`` of an SMPL model called without arguments."""
    model = SMPL(
        create_betas=True,
        create_global_orient=True,
        create_body_pose=True,
        model_path=SMPL_MODEL_DIR,
        batch_size=1
    )
    # The model is built with batch_size=1, so take the only sample.
    return model().smpl_joints[0].detach()
def get_smplx_tpose():
    """Return the vertices of an ``SMPLXLayer`` called without arguments."""
    layer = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)
    return layer().vertices[0]
def get_smplx_tpose_joint():
    """Return the joints of an ``SMPLXLayer`` called without arguments."""
    layer = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)
    return layer().joints[0]
def get_mano_tpose():
    """Return the right-hand MANO vertices for all-zero global and hand poses."""
    hand_model = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=True)
    zero_orient = torch.zeros(1, 3)
    zero_hand_pose = torch.zeros(1, 15 * 3)
    output = hand_model(global_orient=zero_orient, right_hand_pose=zero_hand_pose)
    return output.rhand_vertices[0]
def get_flame_tpose():
    """Return the FLAME vertices for an all-zero global orientation."""
    head_model = FLAME(SMPL_MODEL_DIR, batch_size=1)
    output = head_model(global_orient=torch.zeros(1, 3))
    return output.flame_vertices[0]
def get_part_joints(smpl_joints):
    """Derive 23 part joints from a batch of SMPL joints.

    The first 18 outputs are midpoints of fixed joint pairs; the last 5 are copies of
    joints 10, 11, 15, 22 and 23. The input is indexed along dimension 1 and the
    result keeps the input's batch and coordinate dimensions.
    """
    # Joint pairs whose midpoint becomes a part joint, in output order.
    segment_pairs = [
        (0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), (12, 15), (13, 16), (14, 17),
        (1, 4), (2, 5), (4, 7), (5, 8), (16, 18), (17, 19), (18, 20), (19, 21),
    ]
    # Joints copied through unchanged.
    passthrough = [10, 11, 15, 22, 23]

    pair_index = torch.tensor(segment_pairs, dtype=torch.long, device=smpl_joints.device)
    single_index = torch.tensor(passthrough, dtype=torch.long, device=smpl_joints.device)

    # [B, 18, 2, C] -> average the two endpoints of every pair -> [B, 18, C]
    midpoints = smpl_joints[:, pair_index].mean(dim=2)
    singles = smpl_joints[:, single_index]
    return torch.cat([midpoints, singles], dim=1)
def _submesh_from_seed_vertices(faces, seed_vids, num_verts):
    """Select every face touching ``seed_vids`` and re-index it as a compact sub-mesh.

    Returns ``(vids, sub_faces)``: the sorted unique vertex ids used by the selected
    faces, and those faces rewritten to index into ``vids``. Both are int64 arrays.
    """
    faces = np.asarray(faces)
    # A face is kept when at least one of its vertices is a seed vertex.
    selected = faces[np.flatnonzero(np.isin(faces, seed_vids).any(axis=1))]
    vids = np.unique(selected).astype(np.int64)
    # Lookup table from full-mesh vertex id to sub-mesh vertex id.
    remap = np.arange(num_verts)
    remap[vids] = np.arange(len(vids))
    return vids, remap[selected].astype(np.int64)


def get_partial_smpl(body_model='smpl', device=torch.device('cuda')):
    """Return vertex ids and faces of several body-part sub-meshes of a body model.

    For every part the result is read from
    ``<PARTIAL_MESH_DIR>/<body_model>_<part>_vids.npz`` when that file exists;
    otherwise it is computed and written to that file.

    Parameters
    ----------
    body_model: str
        Model name passed to ``get_model_faces`` / ``get_model_tpose``. The wrist
        parts only support 'smpl' and 'smplx'.
    device: torch.device
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    dict mapping part name to ``{'vids': ndarray, 'faces': ndarray}``.

    Raises
    ------
    ValueError
        If a wrist part has to be computed for an unsupported ``body_model``.
    """
    body_model_faces = get_model_faces(body_model)
    body_model_num_verts = len(get_model_tpose(body_model))

    # Segmentation labels that make up each segmentation-based part.
    # 'arm_eval' is never requested by the loop below; it is kept for reference.
    segments_by_part = {
        'face': ['head'],
        'arm': [
            'rightHand',
            'leftArm',
            'leftShoulder',
            'rightShoulder',
            'rightArm',
            'leftHandIndex1',
            'rightHandIndex1',
            'leftForeArm',
            'rightForeArm',
            'leftHand',
        ],
        'forearm': [
            'rightHand',
            'leftHandIndex1',
            'rightHandIndex1',
            'leftForeArm',
            'rightForeArm',
            'leftHand',
        ],
        'arm_eval': ['leftArm', 'rightArm', 'leftForeArm', 'rightForeArm'],
        'larm': ['leftForeArm'],
        'rarm': ['rightForeArm'],
    }

    part_vert_faces = {}

    for part in ['lhand', 'rhand', 'face', 'arm', 'forearm', 'larm', 'rarm', 'lwrist', 'rwrist']:
        part_vid_fname = '{}/{}_{}_vids.npz'.format(path_config.PARTIAL_MESH_DIR, body_model, part)

        if os.path.exists(part_vid_fname):
            # Read both arrays while the archive is open, then close it.
            with np.load(part_vid_fname) as cached:
                part_vert_faces[part] = {'vids': cached['vids'], 'faces': cached['faces']}
            continue

        if part in ['lhand', 'rhand']:
            # NOTE: pickle can execute arbitrary code on load; only use trusted model files.
            with open(
                os.path.join(SMPL_MODEL_DIR, 'model_transfer/MANO_SMPLX_vertex_ids.pkl'), 'rb'
            ) as pkl_file:
                smplx_mano_id = pickle.load(pkl_file)
            with open(
                os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb'
            ) as pkl_file:
                smplx_smpl_id = pickle.load(pkl_file)

            # Keep numpy and torch apart explicitly instead of relying on implicit
            # ndarray/Tensor arithmetic: transfer in numpy, search in torch.
            smplx_tpose = get_smplx_tpose().detach().cpu()
            smpl_tpose = torch.from_numpy(
                np.matmul(np.asarray(smplx_smpl_id['matrix']), smplx_tpose.numpy())
            )

            hand_key = 'left_hand' if part == 'lhand' else 'right_hand'
            mano_vert = smplx_tpose[smplx_mano_id[hand_key]].to(smpl_tpose.dtype)

            # For every hand vertex, find the closest vertex of the transferred mesh.
            smpl2mano_id = []
            for vert in mano_vert:
                v_diff = smpl_tpose - vert
                v_closest = torch.argmin(torch.sum(v_diff * v_diff, dim=1))
                smpl2mano_id.append(int(v_closest))

            # np.long was removed in NumPy 1.24; use the explicit 64-bit integer type.
            part_vids = np.array(smpl2mano_id).astype(np.int64)
            part_faces = get_mano_faces(hand_type='right' if part == 'rhand' else 'left'
                                       ).astype(np.int64)

        elif part in ['face', 'arm', 'forearm', 'larm', 'rarm']:
            with open(
                os.path.join(SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)),
                'rb'
            ) as json_file:
                smplx_part_id = json.load(json_file)

            part_body_idx = []
            for k in segments_by_part[part]:
                part_body_idx.extend(smplx_part_id[k])

            part_vids, part_faces = _submesh_from_seed_vertices(
                body_model_faces, part_body_idx, body_model_num_verts
            )

        elif part in ['lwrist', 'rwrist']:
            if body_model == 'smplx':
                body_model_verts = get_smplx_tpose()
                tpose_joint = get_smplx_tpose_joint()
            elif body_model == 'smpl':
                body_model_verts = get_smpl_tpose()
                tpose_joint = get_smpl_tpose_joint()
            else:
                raise ValueError(
                    "wrist parts are only supported for 'smpl' and 'smplx', got {!r}".
                    format(body_model)
                )

            wrist_joint = tpose_joint[20] if part == 'lwrist' else tpose_joint[21]

            # The threshold is compared against the SQUARED distance to the wrist joint.
            dist = 0.005
            v_j_dist = torch.sum((body_model_verts - wrist_joint)**2, dim=1)
            wrist_vids = torch.nonzero(v_j_dist < dist, as_tuple=True)[0].detach().cpu().numpy()

            part_vids, part_faces = _submesh_from_seed_vertices(
                body_model_faces, wrist_vids, body_model_num_verts
            )

        # Cache the result so the next call can load it from disk.
        np.savez(part_vid_fname, vids=part_vids, faces=part_faces)
        part_vert_faces[part] = {'vids': part_vids, 'faces': part_faces}

    return part_vert_faces