Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # Using this computer program means that you agree to the terms | |
| # in the LICENSE file included with this software distribution. | |
| # Any use not explicitly granted by the LICENSE is prohibited. | |
| # | |
| # Copyright©2019 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # For comments or questions, please email us at [email protected] | |
| # For commercial licensing contact, please contact [email protected] | |
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision | |
| from skimage.io import imread | |
| from .models.encoders import MLP, HRNEncoder, ResnetEncoder | |
| from .models.moderators import TempSoftmaxFusion | |
| from .models.SMPLX import SMPLX | |
| from .utils import rotation_converter as converter | |
| from .utils import tensor_cropper, util | |
| from .utils.config import cfg | |
class PIXIE(object):
    """Regress SMPL-X parameters from body / head / hand images and decode them to meshes.

    The body branch predicts coarse body parameters, crops head and hand patches
    from the body image, encodes those crops with the part encoders, and fuses
    part and body features with learned moderators before the final regression.
    """

    # Parts that are cropped from the body image and fused back into the body feature.
    _PART_NAMES = ("head", "left_hand", "right_hand")

    def __init__(self, config=None, device="cuda:0"):
        """Build all networks, load the pretrained checkpoint and set up part croppers.

        Args:
            config: configuration node; when None the module-level default ``cfg`` is used.
            device: torch device the networks are moved to (and the checkpoint is mapped to).
        """
        self.cfg = cfg if config is None else config
        self.device = device
        # parameters setting: {"<group>_list": {param_name: dimension}}.
        # Read from self.cfg (not the module default) so a custom config is honoured.
        self.param_list_dict = {
            list_name: {
                param_name: self.cfg.model.get("n_" + param_name)
                for param_name in self.cfg.params.get(list_name)
            }
            for list_name in self.cfg.params.keys()
        }
        # Build the models
        self._create_model()
        # Set up the cropping modules used to generate face/hand crops from the body predictions
        self._setup_cropper()

    def forward(self, data):
        """Encode a batch of body images and decode the fused body parameters.

        Args:
            data: body image batch, forwarded as ``{"body": {"image": data}}`` to :meth:`encode`.

        Returns:
            dict: body predictions produced by :meth:`decode`.
        """
        param_dict = self.encode(
            {"body": {"image": data}},
            threthold=True,
            keep_local=True,
            copy_and_paste=False,
        )
        return self.decode(param_dict["body"], param_type="body")

    def _setup_cropper(self):
        """Create one fixed-scale tensor cropper for heads and one for hands."""
        self.Cropper = {}
        for crop_part in ["head", "hand"]:
            data_cfg = self.cfg.dataset[crop_part]
            # use the mid-point of the configured scale range, without random translation
            scale_size = (data_cfg.scale_min + data_cfg.scale_max) * 0.5
            self.Cropper[crop_part] = tensor_cropper.Cropper(
                crop_size=data_cfg.image_size,
                scale=[scale_size, scale_size],
                trans_scale=0,
            )

    def _create_model(self):
        """Build encoders, regressors, extractors, moderators and SMPL-X, then load weights.

        Raises:
            FileNotFoundError: if ``cfg.pretrained_modelpath`` does not exist.
        """
        self.model_dict = {}
        # Build all image encoders
        # Hand encoder only works for right hand, for left hand, flip inputs and flip the results back
        self.Encoder = {}
        for key in self.cfg.network.encoder.keys():
            encoder_type = self.cfg.network.encoder.get(key).type
            if encoder_type == "resnet50":
                self.Encoder[key] = ResnetEncoder().to(self.device)
            elif encoder_type == "hrnet":
                self.Encoder[key] = HRNEncoder().to(self.device)
            self.model_dict[f"Encoder_{key}"] = self.Encoder[key].state_dict()
        # Build the parameter regressors: 2048-d feature -> flattened parameter vector
        self.Regressor = {}
        for key in self.cfg.network.regressor.keys():
            regressor_cfg = self.cfg.network.regressor.get(key)
            n_output = sum(self.param_list_dict[f"{key}_list"].values())
            channels = [2048] + regressor_cfg.channels + [n_output]
            if regressor_cfg.type == "mlp":
                self.Regressor[key] = MLP(channels=channels).to(self.device)
            self.model_dict[f"Regressor_{key}"] = self.Regressor[key].state_dict()
        # Build the extractors
        # to extract separate head/left hand/right hand feature from body feature
        self.Extractor = {}
        for key in self.cfg.network.extractor.keys():
            extractor_cfg = self.cfg.network.extractor.get(key)
            channels = [2048] + extractor_cfg.channels + [2048]
            if extractor_cfg.type == "mlp":
                self.Extractor[key] = MLP(channels=channels).to(self.device)
            self.model_dict[f"Extractor_{key}"] = self.Extractor[key].state_dict()
        # Build the moderators: take concatenated (body, part) features and output 2 weights
        self.Moderator = {}
        for key in self.cfg.network.moderator.keys():
            moderator_cfg = self.cfg.network.moderator.get(key)
            channels = [2048 * 2] + moderator_cfg.channels + [2]
            self.Moderator[key] = TempSoftmaxFusion(
                detach_inputs=moderator_cfg.detach_inputs,
                detach_feature=moderator_cfg.detach_feature,
                channels=channels,
            ).to(self.device)
            self.model_dict[f"Moderator_{key}"] = self.Moderator[key].state_dict()
        # Build the SMPL-X body model, which we also use to represent faces and
        # hands, using the relevant parts only
        self.smplx = SMPLX(self.cfg.model).to(self.device)
        self.part_indices = self.smplx.part_indices
        # -- resume model
        model_path = self.cfg.pretrained_modelpath
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"pixie trained model path: {model_path} does not exist!")
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
        checkpoint = torch.load(model_path, map_location=self.device)
        for key in self.model_dict.keys():
            # state_dict() tensors share storage with the modules, so this updates the networks
            util.copy_state_dict(self.model_dict[key], checkpoint[key])
        # eval mode
        for module in [self.Encoder, self.Regressor, self.Moderator, self.Extractor]:
            for net in module.values():
                net.eval()

    def decompose_code(self, code, num_dict):
        """Split a flattened parameter tensor into named column slices.

        Args:
            code: tensor whose second dimension is the concatenation of all parameters.
            num_dict: ordered mapping ``{param_name: width}``; slices are taken in this order.

        Returns:
            dict: ``{param_name: code[:, start:end]}``.
        """
        code_dict = {}
        start = 0
        for key in num_dict:
            end = start + int(num_dict[key])
            code_dict[key] = code[:, start:end]
            start = end
        return code_dict

    def part_from_body(self, image, part_key, points_dict, crop_joints=None):
        """Crop a head/left_hand/right_hand patch out of a body image.

        Args:
            image: body image batch; its last two dims are used as the point scale.
            part_key: one of "head", "left_hand", "right_hand".
            points_dict: must contain "smplx_kpt"; every entry is transformed into crop space.
            crop_joints: optional joints to crop around instead of ``points_dict["smplx_kpt"]``.

        Returns:
            tuple: (cropped image, dict of points transformed into the crop, normalized).
        """
        assert part_key in ["head", "left_hand", "right_hand"]
        assert "smplx_kpt" in points_dict.keys()
        # heads are cropped around the face keypoints; hands around their own joints
        indices_key = "face" if part_key == "head" else part_key
        part_indices = self.part_indices[indices_key]
        source_points = crop_joints if crop_joints is not None else points_dict["smplx_kpt"]
        points_for_crop = source_points[:, part_indices]
        # both hands share one cropper
        cropper = self.Cropper["hand" if "hand" in part_key else part_key]
        points_scale = image.shape[-2:]
        cropped_image, tform = cropper.crop(image, points_for_crop, points_scale)
        # transform points (normalized to [-1, 1]) accordingly
        cropped_points_dict = {
            points_key: cropper.transform_points(points, tform, points_scale, normalize=True)
            for points_key, points in points_dict.items()
        }
        return cropped_image, cropped_points_dict

    def _compose_body_params(self, body_dict, body_feature):
        """Merge body-only params with head / left-hand / right-hand share params.

        The share params are regressed from ``body_feature["<part>_share"]``. The hand-share
        regressor predicts right-hand keys, so the left-hand result is renamed to
        ``left_*`` keys (the pose flip back to the left side happens in convert_pose).
        """
        head_share_dict = self.decompose_code(
            self.Regressor["head_share"](body_feature["head_share"]),
            self.param_list_dict["head_share_list"],
        )
        right_hand_share_dict = self.decompose_code(
            self.Regressor["hand_share"](body_feature["right_hand_share"]),
            self.param_list_dict["hand_share_list"],
        )
        left_hand_share_dict = self.decompose_code(
            self.Regressor["hand_share"](body_feature["left_hand_share"]),
            self.param_list_dict["hand_share_list"],
        )
        # change the dict name from right to left
        left_hand_share_dict["left_hand_pose"] = left_hand_share_dict.pop("right_hand_pose")
        left_hand_share_dict["left_wrist_pose"] = left_hand_share_dict.pop("right_wrist_pose")
        return {
            **body_dict,
            **head_share_dict,
            **left_hand_share_dict,
            **right_hand_share_dict,
        }

    def _crop_parts_from_body(self, body_data, coarse_params):
        """Crop head/hand images from the body image using a coarse (unfused) estimate.

        Writes ``head_image``, ``left_hand_image`` and ``right_hand_image`` into
        ``body_data`` in place. ``coarse_params`` is decoded (and thus mutated) by decode().
        """
        prediction_body_only = self.decode(coarse_params, param_type="body")
        points_dict = {
            "smplx_kpt": prediction_body_only["smplx_kpt"],
            "trans_verts": prediction_body_only["transformed_vertices"],
        }
        # crop from the caller's high-res image when given, otherwise upsample the input once
        image_hd = body_data.get("image_hd")
        if image_hd is None:
            image_hd = torchvision.transforms.Resize(1024)(body_data["image"])
        for part_name in self._PART_NAMES:
            cropped_image, _ = self.part_from_body(image_hd, part_name, points_dict)
            body_data[part_name + "_image"] = cropped_image

    def encode(
        self,
        data,
        threthold=True,
        keep_local=True,
        copy_and_paste=False,
        body_only=False,
    ):
        """Encode images to smplx parameters.

        Args:
            data: dict
                key: image_type (body/head/hand)
                value:
                    image: image batch (e.g. [bz, 3, 224, 224], range [0, 1])
                    image_hd (optional, body only): high-res version of image used for
                        cropping parts; if absent, ``image`` is resized to 1024
                    head_image: optional, well-cropped head from body image
                    left_hand_image: optional, well-cropped left hand from body image
                    right_hand_image: optional, well-cropped right hand from body image
                Missing part crops are computed and written back into ``data["body"]``.
            threthold: (sic, "threshold") for hands, fully trust the part feature
                when the moderator's part weight exceeds 0.7.
            keep_local: copy expression and hand poses from the part predictions.
            copy_and_paste: always replace body share features with part features.
            body_only: return the coarse (unfused) body parameters only.

        Returns:
            param_dict: dict
                key: image_type (body/head/hand), plus "body_<part>" part predictions
                    and "moderator_weight" (None when body_only)
                value: param_dict
        """
        for key in data.keys():
            assert key in ["body", "head", "hand"]
        feature = {}
        param_dict = {}
        # Encode features
        for key in data.keys():
            part = key
            # encode feature
            feature[key] = {}
            feature[key][part] = self.Encoder[part](data[key]["image"])
            # for head/hand image
            if key == "head" or key == "hand":
                # predict head/hand-only parameters from part feature
                part_dict = self.decompose_code(
                    self.Regressor[part](feature[key][part]),
                    self.param_list_dict[f"{part}_list"],
                )
                # if input is part data, skip feature fusion: share feature is the same as part feature
                # then predict share parameters
                feature[key][f"{key}_share"] = feature[key][key]
                share_dict = self.decompose_code(
                    self.Regressor[f"{part}_share"](feature[key][f"{part}_share"]),
                    self.param_list_dict[f"{part}_share_list"],
                )
                # compose parameters
                param_dict[key] = {**share_dict, **part_dict}
            # for body image
            if key == "body":
                fusion_weight = {}
                f_body = feature["body"]["body"]
                # extract part feature
                for part_name in self._PART_NAMES:
                    feature["body"][f"{part_name}_share"] = self.Extractor[f"{part_name}_share"](
                        f_body
                    )
                # Body-only parameters never go through fusion, so always regress them.
                # (Previously this only happened when crops were missing, which left
                # body_dict undefined when the caller supplied all part crops.)
                body_dict = self.decompose_code(
                    self.Regressor["body"](f_body),
                    self.param_list_dict["body_list"],
                )
                # -- check if part crops are given, if not, crop parts by coarse body estimation
                parts_missing = any(
                    f"{part_name}_image" not in data[key] for part_name in self._PART_NAMES
                )
                if parts_missing or body_only:
                    # - run without fusion to get coarse estimation, for cropping parts
                    param_dict[key] = self._compose_body_params(body_dict, feature[key])
                    if body_only:
                        param_dict["moderator_weight"] = None
                        return param_dict
                    self._crop_parts_from_body(data[key], param_dict[key])
                # -- encode features from part crops, then fuse feature using the weight from moderator
                for part_name in self._PART_NAMES:
                    part = part_name.split("_")[-1]
                    cropped_image = data[key][part_name + "_image"]
                    # if left hand, flip it as if it is right hand
                    if part_name == "left_hand":
                        cropped_image = torch.flip(cropped_image, dims=(-1,))
                    # run part regressor
                    f_part = self.Encoder[part](cropped_image)
                    part_dict = self.decompose_code(
                        self.Regressor[part](f_part),
                        self.param_list_dict[f"{part}_list"],
                    )
                    part_share_dict = self.decompose_code(
                        self.Regressor[f"{part}_share"](f_part),
                        self.param_list_dict[f"{part}_share_list"],
                    )
                    param_dict["body_" + part_name] = {**part_dict, **part_share_dict}
                    # moderator to assign weight, then integrate features
                    f_body_out, f_part_out, f_weight = self.Moderator[f"{part}_share"](
                        feature["body"][f"{part_name}_share"], f_part, work=True
                    )
                    if copy_and_paste:
                        # copy and paste strategy always trusts the results from part
                        feature["body"][f"{part_name}_share"] = f_part
                    elif threthold and part == "hand":
                        # for hand, if part weight > 0.7 (very confident), fully trust part.
                        # [:, [1]] is a copy, so f_weight (stored below) is left unchanged.
                        part_w = f_weight[:, [1]]
                        part_w[part_w > 0.7] = 1.0
                        f_body_out = (
                            feature["body"][f"{part_name}_share"] * (1.0 - part_w) + f_part * part_w
                        )
                        feature["body"][f"{part_name}_share"] = f_body_out
                    else:
                        feature["body"][f"{part_name}_share"] = f_body_out
                    fusion_weight[part_name] = f_weight
                # save weights from moderator, that can be further used for optimization/running specific tasks on parts
                param_dict["moderator_weight"] = fusion_weight
                # -- predict parameters from fused body feature
                param_dict["body"] = self._compose_body_params(body_dict, feature[key])
                # copy tex param from head param dict to body param dict
                param_dict["body"]["tex"] = param_dict["body_head"]["tex"]
                param_dict["body"]["light"] = param_dict["body_head"]["light"]
                if keep_local:
                    # for local change that will not affect whole body and produce unnatural pose, trust part
                    param_dict[key]["exp"] = param_dict["body_head"]["exp"]
                    param_dict[key]["right_hand_pose"] = param_dict["body_right_hand"][
                        "right_hand_pose"]
                    # the left crop was flipped, so its hand pose is stored under the right-hand key
                    param_dict[key]["left_hand_pose"] = param_dict["body_left_hand"][
                        "right_hand_pose"]
        return param_dict

    def convert_pose(self, param_dict, param_type):
        """Convert pose parameters to rotation matrices and fill in missing params (in place).

        Args:
            param_dict: smplx parameters
            param_type: should be one of body/head/hand

        Returns:
            param_dict: the same dict, updated in place

        Raises:
            ValueError: if param_type is not one of body/head/hand (when asserts are disabled).
        """
        assert param_type in ["body", "head", "hand"]
        # convert pose representations: the output from network are continuous repre or axis angle,
        # while the input pose for smplx need to be rotation matrix
        for key in param_dict:
            if "pose" in key and "jaw" not in key:
                param_dict[key] = converter.batch_cont2matrix(param_dict[key])
        if param_type == "body" or param_type == "head":
            param_dict["jaw_pose"] = converter.batch_euler2matrix(param_dict["jaw_pose"]
                                                                 )[:, None, :, :]

        def _expand(template, batch_size):
            # broadcast a default SMPL-X pose template over the batch dimension (no copy)
            return template.unsqueeze(0).expand(batch_size, -1, -1, -1)

        n_partbody = self.param_list_dict["body_list"]["partbody_pose"]
        # complement params if it's not in given param dict
        if param_type == "head":
            batch_size = param_dict["shape"].shape[0]
            param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
            param_dict["global_pose"] = param_dict["head_pose"]
            param_dict["partbody_pose"] = _expand(self.smplx.body_pose, batch_size)[:, :n_partbody]
            param_dict["neck_pose"] = _expand(self.smplx.neck_pose, batch_size)
            # NOTE(review): wrist defaults reuse neck_pose, presumably an identity rest
            # rotation shared by all joints -- confirm against SMPLX.
            param_dict["left_wrist_pose"] = _expand(self.smplx.neck_pose, batch_size)
            param_dict["left_hand_pose"] = _expand(self.smplx.left_hand_pose, batch_size)
            param_dict["right_wrist_pose"] = _expand(self.smplx.neck_pose, batch_size)
            param_dict["right_hand_pose"] = _expand(self.smplx.right_hand_pose, batch_size)
        elif param_type == "hand":
            batch_size = param_dict["right_hand_pose"].shape[0]
            param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
            dtype = param_dict["right_hand_pose"].dtype
            device = param_dict["right_hand_pose"].device
            # diag(1, -1, -1): 180 degree rotation about x
            x_180_pose = torch.eye(3, dtype=dtype, device=device).unsqueeze(0)
            x_180_pose[0, 2, 2] = -1.0
            x_180_pose[0, 1, 1] = -1.0
            param_dict["global_pose"] = x_180_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
            param_dict["shape"] = self.smplx.shape_params.expand(batch_size, -1)
            param_dict["exp"] = self.smplx.expression_params.expand(batch_size, -1)
            param_dict["head_pose"] = _expand(self.smplx.head_pose, batch_size)
            param_dict["neck_pose"] = _expand(self.smplx.neck_pose, batch_size)
            param_dict["jaw_pose"] = _expand(self.smplx.jaw_pose, batch_size)
            param_dict["partbody_pose"] = _expand(self.smplx.body_pose, batch_size)[:, :n_partbody]
            # NOTE(review): see the neck_pose note in the head branch.
            param_dict["left_wrist_pose"] = _expand(self.smplx.neck_pose, batch_size)
            param_dict["left_hand_pose"] = _expand(self.smplx.left_hand_pose, batch_size)
        elif param_type == "body":
            # the prediction from the head and hand share regressor is always absolute pose
            param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
            param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
            param_dict["abs_left_wrist_pose"] = param_dict["left_wrist_pose"].clone()
            # the body-hand share regressor is working for right hand
            # so we assume body network get the flipped feature for the left hand. then get the parameters
            # then we need to flip it back to left, which matches the input left hand
            param_dict["left_wrist_pose"] = util.flip_pose(param_dict["left_wrist_pose"])
            param_dict["left_hand_pose"] = util.flip_pose(param_dict["left_hand_pose"])
        else:
            raise ValueError(f"unknown param_type: {param_type}")
        return param_dict

    def decode(self, param_dict, param_type):
        """Decode model parameters to smplx vertices & joints & projections.

        Args:
            param_dict: smplx parameters (converted to rotation matrices in place if needed)
            param_type: should be one of body/head/hand

        Returns:
            predictions: smplx predictions merged with param_dict
        """
        # raw network outputs (2-D jaw pose or 6-D wrist pose) still need conversion
        needs_conversion = (
            "jaw_pose" in param_dict.keys() and len(param_dict["jaw_pose"].shape) == 2
        ) or param_dict["right_wrist_pose"].shape[-1] == 6
        if needs_conversion:
            self.convert_pose(param_dict, param_type)
        # concatenate body pose
        partbody_pose = param_dict["partbody_pose"]
        param_dict["body_pose"] = torch.cat(
            [
                partbody_pose[:, :11],
                param_dict["neck_pose"],
                partbody_pose[:, 11:11 + 2],
                param_dict["head_pose"],
                partbody_pose[:, 13:13 + 4],
                param_dict["left_wrist_pose"],
                param_dict["right_wrist_pose"],
            ],
            dim=1,
        )
        # change absolute head&hand pose to relative pose according to rest body pose
        if param_type == "head" or param_type == "body":
            param_dict["body_pose"] = self.smplx.pose_abs2rel(
                param_dict["global_pose"], param_dict["body_pose"], abs_joint="head"
            )
        if param_type == "hand" or param_type == "body":
            for abs_joint in ("left_wrist", "right_wrist"):
                param_dict["body_pose"] = self.smplx.pose_abs2rel(
                    param_dict["global_pose"],
                    param_dict["body_pose"],
                    abs_joint=abs_joint,
                )
        if self.cfg.model.check_pose:
            # check if pose is natural (relative rotation), if not, set relative to 0 (especially for head pose)
            # xyz: pitch(positive for looking down), yaw(positive for looking left), roll(rolling chin to left)
            for pose_ind in [14]:  # head [15-1, 20-1, 21-1]:
                curr_pose = param_dict["body_pose"][:, pose_ind]
                euler_pose = converter._compute_euler_from_matrix(curr_pose)
                for i, max_angle in enumerate([20, 70, 10]):
                    # basic slicing gives a view, so zeroing writes through to euler_pose
                    euler_pose_curr = euler_pose[:, i]
                    euler_pose_curr[euler_pose_curr != torch.clamp(
                        euler_pose_curr,
                        min=-max_angle * np.pi / 180,
                        max=max_angle * np.pi / 180,
                    )] = 0.0
                param_dict["body_pose"][:, pose_ind] = converter.batch_euler2matrix(euler_pose)
        # SMPLX
        verts, landmarks, joints = self.smplx(
            shape_params=param_dict["shape"],
            expression_params=param_dict["exp"],
            global_pose=param_dict["global_pose"],
            body_pose=param_dict["body_pose"],
            jaw_pose=param_dict["jaw_pose"],
            left_hand_pose=param_dict["left_hand_pose"],
            right_hand_pose=param_dict["right_hand_pose"],
        )
        smplx_kpt3d = joints.clone()
        # projection
        cam = param_dict[param_type + "_cam"]
        trans_verts = util.batch_orth_proj(verts, cam)
        predicted_landmarks = util.batch_orth_proj(landmarks, cam)[:, :, :2]
        predicted_joints = util.batch_orth_proj(joints, cam)[:, :, :2]
        prediction = {
            "vertices": verts,
            "transformed_vertices": trans_verts,
            # change the order of face keypoints, to be the same as "standard" 68 keypoints
            "face_kpt": torch.cat(
                [predicted_landmarks[:, -17:], predicted_landmarks[:, :-17]], dim=1
            ),
            "smplx_kpt": predicted_joints,
            "smplx_kpt3d": smplx_kpt3d,
            "joints": joints,
            "cam": cam,
        }
        prediction.update(param_dict)
        return prediction

    def decode_Tpose(self, param_dict):
        """Return the mesh in T pose; supports body and head param dicts only."""
        verts, _, _ = self.smplx(
            shape_params=param_dict["shape"],
            expression_params=param_dict["exp"],
            jaw_pose=param_dict["jaw_pose"],
        )
        return verts