Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2019 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: [email protected] | |
| import os | |
| import lib.smplx as smplx | |
| from lib.pymaf.utils.geometry import rotation_matrix_to_angle_axis, batch_rodrigues | |
| from lib.pymaf.utils.imutils import process_image | |
| from lib.pymaf.core import path_config | |
| from lib.pymaf.models import pymaf_net | |
| from lib.common.config import cfg | |
| from lib.common.render import Render | |
| from lib.dataset.body_model import TetraSMPLModel | |
| from lib.dataset.mesh_util import get_visibility, SMPLX | |
| import os.path as osp | |
| import torch | |
| import numpy as np | |
| import random | |
| import human_det | |
| from termcolor import colored | |
| from PIL import ImageFile | |
| from huggingface_hub import cached_download | |
# Module-wide PIL setting: allow truncated image files to be loaded instead
# of raising an error while decoding them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
class TestDataset():
    """Single-image inference dataset built around an HPS (human pose and
    shape) estimator.

    Construction loads the SMPL(-X) body model, the PyMAF network weights and
    a 512px renderer. ``__getitem__`` preprocesses the configured image, runs
    the HPS network and returns a dict of body parameters, vertices and
    camera values.
    """

    def __init__(self, cfg, device):
        """
        Args:
            cfg: mapping providing 'image_path', 'seg_dir', 'has_det' and
                'hps_type'.
            device: torch device (or device string) that models and output
                tensors are moved to.
        """
        random.seed(1993)

        self.image_path = cfg['image_path']
        self.seg_dir = cfg['seg_dir']
        self.has_det = cfg['has_det']
        self.hps_type = cfg['hps_type']
        # PIXIE is paired with the SMPL-X body model; every other estimator
        # uses SMPL.
        self.smpl_type = 'smpl' if cfg['hps_type'] != 'pixie' else 'smplx'
        self.smpl_gender = 'neutral'
        self.device = device

        self.det = human_det.Detection() if self.has_det else None

        # The dataset only ever holds the single configured image.
        self.subject_list = [self.image_path]

        # smpl related
        self.smpl_data = SMPLX()
        self.get_smpl_model = lambda smpl_type, smpl_gender: smplx.create(
            model_path=self.smpl_data.model_dir,
            gender=smpl_gender,
            model_type=smpl_type,
            ext='npz')

        # Load SMPL model
        self.smpl_model = self.get_smpl_model(
            self.smpl_type, self.smpl_gender).to(self.device)
        self.faces = self.smpl_model.faces

        # NOTE: the PyMAF network is loaded regardless of ``hps_type``.
        self.hps = pymaf_net(path_config.SMPL_MEAN_PARAMS,
                             pretrained=True).to(self.device)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # host (without it torch.load tries to restore the original device).
        checkpoint = torch.load(path_config.CHECKPOINT_FILE,
                                map_location=self.device)
        self.hps.load_state_dict(checkpoint['model'], strict=True)
        self.hps.eval()

        print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green"))

        self.render = Render(size=512, device=device)

    def __len__(self):
        """Return the number of subjects (always one configured image)."""
        return len(self.subject_list)

    @staticmethod
    def _image_name(img_path):
        """Return the file name of ``img_path`` without directory or final
        extension (e.g. '/a/b/img.01.png' -> 'img.01')."""
        return osp.splitext(osp.basename(img_path))[0]

    @staticmethod
    def _faces_to_tensor(faces, device):
        """Return ``faces`` as an int64 tensor with a leading batch dim, on
        ``device``.

        Converting straight to int64 avoids the previous int16/float32
        round-trip, which silently wraps vertex indices >= 32768.
        """
        return torch.as_tensor(
            np.asarray(faces).astype(np.int64)).unsqueeze(0).to(device)

    @staticmethod
    def _hf_token():
        """Return the Hugging Face access token from the ICON env variable.

        Raises:
            KeyError: if ICON is not set (same exception type as a plain
                ``os.environ['ICON']`` lookup, but with an explanatory
                message).
        """
        try:
            return os.environ['ICON']
        except KeyError:
            raise KeyError(
                "Environment variable 'ICON' must hold a Hugging Face access "
                "token to download the SMPL / tetrahedral assets") from None

    def compute_vis_cmap(self, smpl_verts, smpl_faces):
        """Compute vertex visibility and the SMPL-X colour map for a mesh.

        Args:
            smpl_verts: tensor or array-like whose dim 1 has size 3 (assumed
                (num_verts, 3) — TODO confirm). The first two columns are
                passed to ``get_visibility`` as xy and the last, negated, as z.
            smpl_faces: face index array/tensor; cast to int64.

        Returns:
            dict with 'smpl_vis' and 'smpl_cmap' (batch dim added, moved to
            ``self.device``) and 'smpl_verts' (input as a tensor, batch dim
            added, left on its original device).
        """
        verts = torch.as_tensor(smpl_verts)
        xy, z = verts.split([2, 1], dim=1)
        smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())

        # The colour map is indexed in SMPL-X vertex order, so SMPL vertices
        # are first mapped to their SMPL-X counterparts.
        num_verts = smpl_vis.shape[0]
        if self.smpl_type == 'smpl':
            smplx_ind = self.smpl_data.smpl2smplx(np.arange(num_verts))
        else:
            smplx_ind = np.arange(num_verts)
        smpl_cmap = self.smpl_data.get_smpl_mat(smplx_ind)

        return {
            'smpl_vis': smpl_vis.unsqueeze(0).to(self.device),
            'smpl_cmap': smpl_cmap.unsqueeze(0).to(self.device),
            # Use the converted tensor so array-like input also works.
            'smpl_verts': verts.unsqueeze(0)
        }

    def compute_voxel_verts(self, body_pose, global_orient, betas, trans,
                            scale):
        """Pose the tetrahedral SMPL model and return padded voxel tensors.

        Downloads (via ``cached_download``) the neutral SMPL model and the
        adult tetrahedral data, poses the model with the first batch entry of
        the given parameters, scales/translates the vertices and pads vertices
        to 8000 rows and tetrahedra to 25100 rows.

        Args:
            body_pose, global_orient: rotation-matrix tensors; element [0] of
                each is concatenated and converted to axis-angle.
            betas: shape coefficients; element [0] is used.
            trans: translation tensor added to the vertices.
            scale: single-element tensor multiplying the vertices.

        Returns:
            dict of 'voxel_verts' (float), 'voxel_faces' (long), 'pad_v_num'
            and 'pad_f_num' (long), each with a leading batch dim on
            ``self.device``.
        """
        token = self._hf_token()
        smpl_path = cached_download(
            osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl"),
            use_auth_token=token)
        tetra_path = cached_download(
            osp.join(self.smpl_data.tedra_dir, 'tetra_neutral_adult_smpl.npz'),
            use_auth_token=token)
        smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')

        pose = torch.cat([global_orient[0], body_pose[0]], dim=0)
        smpl_model.set_params(rotation_matrix_to_angle_axis(pose),
                              beta=betas[0])

        verts = np.concatenate(
            [smpl_model.verts, smpl_model.verts_added],
            axis=0) * scale.item() + trans.detach().cpu().numpy()
        # The tetrahedron file stores 1-based indices; convert to 0-based.
        faces = np.loadtxt(
            cached_download(osp.join(self.smpl_data.tedra_dir,
                                     'tetrahedrons_neutral_adult.txt'),
                            use_auth_token=token),
            dtype=np.int32) - 1

        # Pad to fixed sizes (presumably so every sample has the same tensor
        # shapes downstream); the pad counts are returned alongside.
        pad_v_num = int(8000 - verts.shape[0])
        pad_f_num = int(25100 - faces.shape[0])

        verts = np.pad(verts, ((0, pad_v_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.float32) * 0.5
        faces = np.pad(faces, ((0, pad_f_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.int32)

        verts[:, 2] *= -1.0

        return {
            'voxel_verts':
            torch.from_numpy(verts).to(self.device).unsqueeze(0).float(),
            'voxel_faces':
            torch.from_numpy(faces).to(self.device).unsqueeze(0).long(),
            'pad_v_num':
            torch.tensor(pad_v_num).to(self.device).unsqueeze(0).long(),
            'pad_f_num':
            torch.tensor(pad_f_num).to(self.device).unsqueeze(0).long()
        }

    def _unpack_hps(self, preds_dict, data_dict):
        """Copy body parameters from an HPS prediction into ``data_dict``.

        Writes 'betas', 'body_pose', 'global_orient' and 'smpl_verts' (PIXIE
        additionally copies every prediction key first).

        Returns:
            (scale, tranX, tranY) camera values for the configured estimator.

        Raises:
            ValueError: if ``self.hps_type`` is not a supported estimator.
        """
        if self.hps_type == 'pymaf':
            output = preds_dict['smpl_out'][-1]
            scale, tranX, tranY = output['theta'][0, :3]
            data_dict['betas'] = output['pred_shape']
            data_dict['body_pose'] = output['rotmat'][:, 1:]
            data_dict['global_orient'] = output['rotmat'][:, 0:1]
            data_dict['smpl_verts'] = output['verts']

        elif self.hps_type == 'pare':
            data_dict['body_pose'] = preds_dict['pred_pose'][:, 1:]
            data_dict['global_orient'] = preds_dict['pred_pose'][:, 0:1]
            data_dict['betas'] = preds_dict['pred_shape']
            data_dict['smpl_verts'] = preds_dict['smpl_vertices']
            scale, tranX, tranY = preds_dict['pred_cam'][0, :3]

        elif self.hps_type == 'pixie':
            data_dict.update(preds_dict)
            data_dict['body_pose'] = preds_dict['body_pose']
            data_dict['global_orient'] = preds_dict['global_pose']
            data_dict['betas'] = preds_dict['shape']
            data_dict['smpl_verts'] = preds_dict['vertices']
            scale, tranX, tranY = preds_dict['cam'][0, :3]

        elif self.hps_type == 'hybrik':
            data_dict['body_pose'] = preds_dict['pred_theta_mats'][:, 1:]
            data_dict['global_orient'] = preds_dict['pred_theta_mats'][:, [0]]
            data_dict['betas'] = preds_dict['pred_shape']
            data_dict['smpl_verts'] = preds_dict['pred_vertices']
            scale, tranX, tranY = preds_dict['pred_camera'][0, :3]
            scale = scale * 2

        elif self.hps_type == 'bev':
            # BEV returns numpy arrays with axis-angle thetas; convert to
            # rotation-matrix tensors to match the other estimators.
            data_dict['betas'] = torch.from_numpy(preds_dict['smpl_betas'])[
                [0], :10].to(self.device).float()
            pred_thetas = batch_rodrigues(torch.from_numpy(
                preds_dict['smpl_thetas'][0]).reshape(-1, 3)).float()
            data_dict['body_pose'] = pred_thetas[1:][None].to(self.device)
            data_dict['global_orient'] = pred_thetas[[0]][None].to(self.device)
            data_dict['smpl_verts'] = torch.from_numpy(
                preds_dict['verts'][[0]]).to(self.device).float()
            tranX = preds_dict['cam_trans'][0, 0]
            tranY = preds_dict['cam'][0, 1] + 0.28
            scale = preds_dict['cam'][0, 0] * 1.1

        else:
            # Previously this fell through and crashed with UnboundLocalError.
            raise ValueError(f"Unsupported hps_type: {self.hps_type!r}")

        return scale, tranX, tranY

    def __getitem__(self, index):
        """Preprocess subject ``index``, run the HPS network and return its
        predictions.

        Returned keys: 'name', 'image', 'ori_image', 'mask', 'uncrop_param',
        'segmentations' (only when ``seg_dir`` is set), 'smpl_faces', the
        body-parameter keys written by the estimator branch, 'scale' and
        'trans'.
        """
        img_path = self.subject_list[index]
        img_name = self._image_name(img_path)

        if self.seg_dir is None:
            img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image(
                img_path, self.det, self.hps_type, 512, self.device)
            segmentations = None
        else:
            (img_icon, img_hps, img_ori, img_mask, uncrop_param,
             segmentations) = process_image(
                 img_path, self.det, self.hps_type, 512, self.device,
                 seg_path=os.path.join(self.seg_dir, f'{img_name}.json'))

        data_dict = {
            'name': img_name,
            'image': img_icon.to(self.device).unsqueeze(0),
            'ori_image': img_ori,
            'mask': img_mask,
            'uncrop_param': uncrop_param
        }
        if self.seg_dir is not None:
            data_dict['segmentations'] = segmentations

        with torch.no_grad():
            preds_dict = self.hps.forward(img_hps)

        data_dict['smpl_faces'] = self._faces_to_tensor(self.faces,
                                                        self.device)

        scale, tranX, tranY = self._unpack_hps(preds_dict, data_dict)

        data_dict['scale'] = scale
        data_dict['trans'] = torch.tensor(
            [tranX, tranY, 0.0]).to(self.device).float()

        # data_dict info (key-shape), per the original author's notes:
        # scale, tranX, tranY - tensor.float
        # betas - [1,10] / [1, 200]
        # body_pose - [1, 23, 3, 3] / [1, 21, 3, 3]
        # global_orient - [1, 1, 3, 3]
        # smpl_verts - [1, 6890, 3] / [1, 10475, 3]

        return data_dict

    def render_normal(self, verts, faces):
        """Load the mesh into the renderer and return ``get_rgb_image()``."""
        self.render.load_meshes(verts, faces)
        return self.render.get_rgb_image()

    def render_depth(self, verts, faces):
        """Load the mesh into the renderer and return its depth maps for
        camera ids 0 and 2."""
        self.render.load_meshes(verts, faces)
        return self.render.get_depth_map(cam_ids=[0, 2])