Spaces:
Runtime error
Runtime error
| """ | |
| original from https://github.com/vchoutas/smplx | |
| modified by Vassilis and Yao | |
| """ | |
| import pickle | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from .lbs import ( | |
| JointsFromVerticesSelector, | |
| Struct, | |
| find_dynamic_lmk_idx_and_bcoords, | |
| lbs, | |
| to_np, | |
| to_tensor, | |
| vertices2landmarks, | |
| ) | |
# SMPLX
# Names of the 14-joint (J14) set. SMPLX.__init__ matches SMPL-X keypoint
# names against this list and uses positions in it as target indices.
# Grouping: right leg (ankle -> hip), left leg (hip -> ankle),
# right arm (wrist -> shoulder), left arm (shoulder -> wrist), neck, head.
J14_NAMES = (
    "right_ankle right_knee right_hip "
    "left_hip left_knee left_ankle "
    "right_wrist right_elbow right_shoulder "
    "left_shoulder left_elbow left_wrist "
    "neck head"
).split()
# Skeleton joints 0-24 (body, jaw and eyes).
_SMPLX_BODY_JOINT_NAMES = [
    "pelvis", "left_hip", "right_hip", "spine1",
    "left_knee", "right_knee", "spine2",
    "left_ankle", "right_ankle", "spine3",
    "left_foot", "right_foot", "neck",
    "left_collar", "right_collar", "head",
    "left_shoulder", "right_shoulder",
    "left_elbow", "right_elbow",
    "left_wrist", "right_wrist",
    "jaw", "left_eye_smplx", "right_eye_smplx",
]
# Finger joints 25-54: left hand then right hand; per hand the fingers
# index, middle, pinky, ring, thumb, each with segments 1-3.
_SMPLX_HAND_JOINT_NAMES = [
    f"{side}_{finger}{segment}"
    for side in ("left", "right")
    for finger in ("index", "middle", "pinky", "ring", "thumb")
    for segment in (1, 2, 3)
]
# 68 face landmarks, indices 55-122.
_SMPLX_FACE_LANDMARK_NAMES = [
    # eyebrows (55-64)
    "right_eye_brow1", "right_eye_brow2", "right_eye_brow3",
    "right_eye_brow4", "right_eye_brow5",
    "left_eye_brow5", "left_eye_brow4", "left_eye_brow3",
    "left_eye_brow2", "left_eye_brow1",
    # nose (65-73)
    "nose1", "nose2", "nose3", "nose4",
    "right_nose_2", "right_nose_1", "nose_middle", "left_nose_1", "left_nose_2",
    # eyes (74-85)
    "right_eye1", "right_eye2", "right_eye3",
    "right_eye4", "right_eye5", "right_eye6",
    "left_eye4", "left_eye3", "left_eye2",
    "left_eye1", "left_eye6", "left_eye5",
    # outer mouth (86-97)
    "right_mouth_1", "right_mouth_2", "right_mouth_3", "mouth_top",
    "left_mouth_3", "left_mouth_2", "left_mouth_1", "left_mouth_5",
    "left_mouth_4", "mouth_bottom", "right_mouth_4", "right_mouth_5",
    # lips (98-105)
    "right_lip_1", "right_lip_2", "lip_top", "left_lip_2",
    "left_lip_1", "left_lip_3", "lip_bottom", "right_lip_3",
    # face contour (106-122)
    "right_contour_1", "right_contour_2", "right_contour_3", "right_contour_4",
    "right_contour_5", "right_contour_6", "right_contour_7", "right_contour_8",
    "contour_middle",
    "left_contour_8", "left_contour_7", "left_contour_6", "left_contour_5",
    "left_contour_4", "left_contour_3", "left_contour_2", "left_contour_1",
]
# Extra keypoints, indices 123-144. Presumably these are the joints produced
# by the JointsFromVerticesSelector in SMPLX.forward (TODO confirm against the
# extra-joint file).
extra_names = [
    "head_top",
    "left_big_toe", "left_ear", "left_eye", "left_heel", "left_index",
    "left_middle", "left_pinky", "left_ring", "left_small_toe", "left_thumb",
    "nose",
    "right_big_toe", "right_ear", "right_eye", "right_heel", "right_index",
    "right_middle", "right_pinky", "right_ring", "right_small_toe", "right_thumb",
]
# Full keypoint name list (145 unique names), one entry per row of the
# ``joints`` tensor documented in SMPLX.forward.
# The extra names appear exactly once here; previously they were listed
# inline AND appended again via ``+= extra_names``, giving 22 duplicates.
SMPLX_names = (
    _SMPLX_BODY_JOINT_NAMES
    + _SMPLX_HAND_JOINT_NAMES
    + _SMPLX_FACE_LANDMARK_NAMES
    + extra_names
)
# Index subsets into the keypoint list SMPLX_names, grouped by body part.
part_indices = {
    # skeleton joints 0-24 plus the non-finger extra keypoints
    # (head top, ears, eyes, nose, big/small toes, heels)
    "body": np.array([*range(0, 25), *range(123, 128), 132, *range(134, 139), 143]),
    # selected trunk joints, jaw/eyes, right eyebrow (55-59) and every
    # keypoint from 76 through 144
    "torso": np.array(
        [0, 1, 2, 3, 6, 9, *range(12, 20), 22, 23, 24, *range(55, 60), *range(76, 145)]
    ),
    # neck, head, jaw/eyes, all face landmarks up to head_top (123), plus the
    # extra left ear/eye (125, 126), nose (134) and right ear/eye (136, 137)
    "head": np.array([12, 15, 22, 23, 24, *range(55, 124), 125, 126, 134, 136, 137]),
    # the 68 face landmarks
    "face": np.array([*range(55, 123)]),
    # neck and collars plus the 68 face landmarks
    "upper": np.array([12, 13, 14, *range(55, 123)]),
    # both wrists, all finger joints and the fingertip extra keypoints
    "hand": np.array(
        [20, 21, *range(25, 55), *range(128, 132), 133, *range(139, 143), 144]
    ),
    # left wrist, left finger joints and left fingertip extras
    "left_hand": np.array([20, *range(25, 40), *range(128, 132), 133]),
    # right wrist, right finger joints and right fingertip extras
    "right_hand": np.array([21, *range(40, 55), *range(139, 143), 144]),
}
# kinematic tree
# Chain from the head down to the root:
# head(15) -> neck(12) -> spine3(9) -> spine2(6) -> spine1(3) -> pelvis(0).
# Written root-first below and reversed so the list starts at the head.
head_kin_chain = list(reversed([0, 3, 6, 9, 12, 15]))
# --smplx joints: index - label used here (name in SMPLX_names)
# 00 - Global      (pelvis)
# 01 - L_Thigh     (left_hip)
# 02 - R_Thigh     (right_hip)
# 03 - Spine       (spine1)
# 04 - L_Calf      (left_knee)
# 05 - R_Calf      (right_knee)
# 06 - Spine1      (spine2)
# 07 - L_Foot      (left_ankle)
# 08 - R_Foot      (right_ankle)
# 09 - Spine2      (spine3)
# 10 - L_Toes      (left_foot)
# 11 - R_Toes      (right_foot)
# 12 - Neck        (neck)
# 13 - L_Shoulder  (left_collar)
# 14 - R_Shoulder  (right_collar)
# 15 - Head        (head)
# 16 - L_UpperArm  (left_shoulder)
# 17 - R_UpperArm  (right_shoulder)
# 18 - L_ForeArm   (left_elbow)
# 19 - R_ForeArm   (right_elbow)
# 20 - L_Hand      (left_wrist)
# 21 - R_Hand      (right_wrist)
# 22 - Jaw         (jaw)
# 23 - L_Eye       (left_eye_smplx)
# 24 - R_Eye       (right_eye_smplx)
class SMPLX(nn.Module):
    """
    Given smplx parameters, this class generates a differentiable SMPLX function
    which outputs a mesh and 3D joints.

    All model data (template, blend shapes, regressors, landmark embeddings and
    default identity/zero parameters) is stored as registered buffers.
    """

    # Joint indices from a given joint down to the pelvis (joint 0), used by
    # pose_abs2rel / pose_rel2abs. Index meanings follow SMPLX_names.
    _ABS_JOINT_KIN_CHAINS = {
        # Pelvis -> Spine 1, 2, 3 -> Neck -> Head
        "head": (15, 12, 9, 6, 3, 0),
        # Pelvis -> Spine 1, 2, 3 -> Neck
        "neck": (12, 9, 6, 3, 0),
        # Pelvis -> Spine 1, 2, 3 -> right Collar -> right shoulder
        # -> right elbow -> right wrist
        "right_wrist": (21, 19, 17, 14, 9, 6, 3, 0),
        # Pelvis -> Spine 1, 2, 3 -> Left Collar -> Left shoulder
        # -> Left elbow -> Left wrist
        "left_wrist": (20, 18, 16, 13, 9, 6, 3, 0),
    }

    def __init__(self, config):
        """Load the SMPL-X model file and register its data as buffers.

        Args:
            config: object providing
                ``smplx_model_path``: file loaded with ``np.load``;
                ``n_shape`` / ``n_exp``: number of shape / expression
                components to keep;
                ``extra_joint_path``: file for ``JointsFromVerticesSelector``
                (falsy to skip extra joints);
                ``j14_regressor_path``: pickled J14 regressor array.
        """
        super(SMPLX, self).__init__()
        # NOTE(review): allow_pickle=True here and pickle.load below can execute
        # arbitrary code stored in the file; only load trusted model files.
        ss = np.load(config.smplx_model_path, allow_pickle=True)
        smplx_model = Struct(**ss)
        self.dtype = torch.float32
        self.register_buffer(
            "faces_tensor",
            to_tensor(to_np(smplx_model.f, dtype=np.int64), dtype=torch.long),
        )
        # The vertices of the template model
        self.register_buffer(
            "v_template", to_tensor(to_np(smplx_model.v_template), dtype=self.dtype)
        )
        # Shape components: first n_shape columns. Expression components start
        # at column 300 (expression space is the same as FLAME).
        shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype)
        shapedirs = torch.cat(
            [
                shapedirs[:, :, : config.n_shape],
                shapedirs[:, :, 300 : 300 + config.n_exp],
            ],
            2,
        )
        self.register_buffer("shapedirs", shapedirs)
        # Pose blend shapes, flattened and transposed to [num_pose_basis, -1]
        num_pose_basis = smplx_model.posedirs.shape[-1]
        posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T
        self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=self.dtype))
        self.register_buffer(
            "J_regressor", to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype)
        )
        # Parent joint index for every joint; the root is marked with -1.
        parents = to_tensor(to_np(smplx_model.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer("parents", parents)
        self.register_buffer(
            "lbs_weights", to_tensor(to_np(smplx_model.weights), dtype=self.dtype)
        )
        # Face keypoints: static landmark embedding ...
        self.register_buffer(
            "lmk_faces_idx", torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long)
        )
        self.register_buffer(
            "lmk_bary_coords",
            torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype),
        )
        # ... and dynamic embedding, selected per sample in forward() by
        # find_dynamic_lmk_idx_and_bcoords
        self.register_buffer(
            "dynamic_lmk_faces_idx",
            torch.tensor(smplx_model.dynamic_lmk_faces_idx, dtype=torch.long),
        )
        self.register_buffer(
            "dynamic_lmk_bary_coords",
            torch.tensor(smplx_model.dynamic_lmk_bary_coords, dtype=self.dtype),
        )
        # pelvis to head, to calculate head yaw angle, then find the dynamic landmarks
        self.register_buffer(
            "head_kin_chain", torch.tensor(head_kin_chain, dtype=torch.long)
        )

        # -- default parameters, used by forward() for arguments left as None
        # shape and expression: zeros, shape [1, n]
        self.register_buffer(
            "shape_params",
            nn.Parameter(
                torch.zeros([1, config.n_shape], dtype=self.dtype), requires_grad=False
            ),
        )
        self.register_buffer(
            "expression_params",
            nn.Parameter(
                torch.zeros([1, config.n_exp], dtype=self.dtype), requires_grad=False
            ),
        )
        # Poses are rotation matrices [number of joints, 3, 3], initialised to
        # identity. head_pose and neck_pose are not read by forward(); they are
        # still registered so the module's buffer names stay the same.
        default_pose_sizes = (
            ("global_pose", 1),
            ("head_pose", 1),
            ("neck_pose", 1),
            ("jaw_pose", 1),
            ("eye_pose", 2),
            ("body_pose", 21),
            ("left_hand_pose", 15),
            ("right_hand_pose", 15),
        )
        for buffer_name, n_joints in default_pose_sizes:
            self.register_buffer(
                buffer_name,
                nn.Parameter(
                    torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(n_joints, 1, 1),
                    requires_grad=False,
                ),
            )

        if config.extra_joint_path:
            self.extra_joint_selector = JointsFromVerticesSelector(
                fname=config.extra_joint_path
            )
        self.use_joint_regressor = True
        self.keypoint_names = SMPLX_names
        if self.use_joint_regressor:
            with open(config.j14_regressor_path, "rb") as f:
                j14_regressor = pickle.load(f, encoding="latin1")
            # For every keypoint whose name is also a J14 joint, record
            # source = its index in keypoint_names, target = its J14 index.
            j14_position = {name: i for i, name in enumerate(J14_NAMES)}
            source = []
            target = []
            for idx, name in enumerate(self.keypoint_names):
                if name in j14_position:
                    source.append(idx)
                    target.append(j14_position[name])
            self.register_buffer("source_idxs", torch.from_numpy(np.asarray(source)))
            self.register_buffer("target_idxs", torch.from_numpy(np.asarray(target)))
            self.register_buffer(
                "extra_joint_regressor",
                torch.from_numpy(j14_regressor).to(torch.float32),
            )
        self.part_indices = part_indices

    @staticmethod
    def _infer_batch_size(*tensors):
        """Return ``shape[0]`` of the first non-None tensor, or 1 if all are None."""
        for tensor in tensors:
            if tensor is not None:
                return tensor.shape[0]
        return 1

    def forward(
        self,
        shape_params=None,
        expression_params=None,
        global_pose=None,
        body_pose=None,
        jaw_pose=None,
        eye_pose=None,
        left_hand_pose=None,
        right_hand_pose=None,
    ):
        """
        Args:
            shape_params: [N, number of shape parameters]
            expression_params: [N, number of expression parameters]
            global_pose: pelvis pose, [N, 1, 3, 3]
            body_pose: [N, 21, 3, 3]
            jaw_pose: [N, 1, 3, 3]
            eye_pose: [N, 2, 3, 3]
            left_hand_pose: [N, 15, 3, 3]
            right_hand_pose: [N, 15, 3, 3]

        Any argument left as None is replaced by the matching default buffer
        (zeros for shape/expression, identity rotations for poses) expanded to
        the batch size N. N is taken from the first given argument, checked in
        the order shape_params, global_pose, expression_params, body_pose,
        jaw_pose, eye_pose, left_hand_pose, right_hand_pose; N is 1 when all
        arguments are None.

        Returns:
            vertices: [N, number of vertices, 3]
            landmarks: [N, number of landmarks (68 face keypoints), 3]
            joints: [N, number of smplx joints (145), 3]
        """
        batch_size = self._infer_batch_size(
            shape_params,
            global_pose,
            expression_params,
            body_pose,
            jaw_pose,
            eye_pose,
            left_hand_pose,
            right_hand_pose,
        )
        if shape_params is None:
            shape_params = self.shape_params.expand(batch_size, -1)
        if expression_params is None:
            expression_params = self.expression_params.expand(batch_size, -1)

        def default_pose(buffer):
            # [n_joints, 3, 3] buffer -> [batch_size, n_joints, 3, 3] view
            return buffer.unsqueeze(0).expand(batch_size, -1, -1, -1)

        if global_pose is None:
            global_pose = default_pose(self.global_pose)
        if body_pose is None:
            body_pose = default_pose(self.body_pose)
        if jaw_pose is None:
            jaw_pose = default_pose(self.jaw_pose)
        if eye_pose is None:
            eye_pose = default_pose(self.eye_pose)
        if left_hand_pose is None:
            left_hand_pose = default_pose(self.left_hand_pose)
        if right_hand_pose is None:
            right_hand_pose = default_pose(self.right_hand_pose)

        shape_components = torch.cat([shape_params, expression_params], dim=1)
        # Joint order: global, body (21), jaw, eyes (2), left hand (15), right hand (15)
        full_pose = torch.cat(
            [
                global_pose,
                body_pose,
                jaw_pose,
                eye_pose,
                left_hand_pose,
                right_hand_pose,
            ],
            dim=1,
        )
        template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
        # smplx linear blend skinning; poses are already rotation matrices
        vertices, joints = lbs(
            shape_components,
            full_pose,
            template_vertices,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            dtype=self.dtype,
            pose2rot=False,
        )
        # face landmarks: static embedding followed by the dynamic (contour)
        # embedding chosen from the head rotation
        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(
            batch_size, -1, -1
        )
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.head_kin_chain,
        )
        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)
        landmarks = vertices2landmarks(
            vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords
        )

        # Final joint set: skeleton joints, face landmarks, then any extra
        # vertex-based joints.
        final_joint_set = [joints, landmarks]
        if hasattr(self, "extra_joint_selector"):
            extra_joints = self.extra_joint_selector(vertices, self.faces_tensor)
            final_joint_set.append(extra_joints)
        joints = torch.cat(final_joint_set, dim=1)
        # NOTE: the J14 regressor buffers (extra_joint_regressor, source_idxs,
        # target_idxs) are registered in __init__ but are not applied here.
        return vertices, landmarks, joints

    @classmethod
    def _kin_chain_to(cls, abs_joint, caller):
        """Return the joint indices from ``abs_joint`` down to the pelvis.

        Raises:
            NotImplementedError: ``"<caller> does not support: <abs_joint>"``
                for joints without a listed chain.
        """
        try:
            return cls._ABS_JOINT_KIN_CHAINS[abs_joint]
        except (KeyError, TypeError):
            raise NotImplementedError(f"{caller} does not support: {abs_joint}") from None

    def pose_abs2rel(self, global_pose, body_pose, abs_joint="head"):
        """change absolute pose to relative pose
        Basic knowledge for SMPLX kinematic tree:
                absolute pose = parent pose * relative pose
        Here, pose must be represented as rotation matrix (batch_sizexnx3x3)

        ``body_pose[:, j - 1]`` (j = the index of ``abs_joint``) is read as an
        absolute rotation and replaced, IN PLACE, by the rotation relative to its
        parent chain. The parent chain is detached, so no gradient flows
        into it. Returns the (same, modified) ``body_pose`` tensor.
        Supported ``abs_joint``: "head", "neck", "right_wrist", "left_wrist".
        """
        kin_chain = self._kin_chain_to(abs_joint, "pose_abs2rel")
        batch_size = global_pose.shape[0]
        dtype = global_pose.dtype
        device = global_pose.device
        # full_pose[:, 0] is the global pose; full_pose[:, i] is body_pose[:, i - 1]
        full_pose = torch.cat([global_pose, body_pose], dim=1)
        rel_rot_mat = (
            torch.eye(3, device=device, dtype=dtype)
            .unsqueeze_(dim=0)
            .repeat(batch_size, 1, 1)
        )
        for idx in kin_chain[1:]:
            rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat)
        # This contains the absolute pose of the parent
        abs_parent_pose = rel_rot_mat.detach()
        # The input is assumed to hold this joint as an absolute rotation
        abs_joint_pose = body_pose[:, kin_chain[0] - 1]
        # abs_joint = abs_parent * rel_joint ==> rel_joint = abs_parent.T * abs_joint
        rel_joint_pose = torch.matmul(
            abs_parent_pose.reshape(-1, 3, 3).transpose(1, 2),
            abs_joint_pose.reshape(-1, 3, 3),
        )
        # Replace the new relative pose
        body_pose[:, kin_chain[0] - 1, :, :] = rel_joint_pose
        return body_pose

    def pose_rel2abs(self, global_pose, body_pose, abs_joint="head"):
        """change relative pose to absolute pose
        Basic knowledge for SMPLX kinematic tree:
                absolute pose = parent pose * relative pose
        Here, pose must be represented as rotation matrix (batch_sizexnx3x3)

        Returns the absolute rotation of ``abs_joint`` with shape
        [batch_size, 1, 3, 3]. Inputs are not modified.
        Supported ``abs_joint``: "head", "neck", "right_wrist", "left_wrist".
        """
        full_pose = torch.cat([global_pose, body_pose], dim=1)
        kin_chain = self._kin_chain_to(abs_joint, "pose_rel2abs")
        abs_rot_mat = torch.eye(
            3, device=full_pose.device, dtype=full_pose.dtype
        ).unsqueeze_(dim=0)
        # Accumulate R_root @ ... @ R_parent @ R_joint (joint applied first)
        for idx in kin_chain:
            abs_rot_mat = torch.matmul(full_pose[:, idx], abs_rot_mat)
        abs_pose = abs_rot_mat[:, None, :, :]
        return abs_pose