from core.remesh import calc_vertex_normals
from core.opt import MeshOptimizer
from utils.func import make_sparse_camera, make_round_views
from utils.render import NormalsRenderer
import torch.optim as optim
from tqdm import tqdm
from utils.video_utils import write_video
from omegaconf import OmegaConf
import numpy as np
import os
from PIL import Image
import kornia
import torch
import torch.nn as nn
import trimesh
from icecream import ic
from utils.project_mesh import multiview_color_projection, project_color, get_cameras_list
from utils.mesh_utils import to_py3d_mesh, rot6d_to_rotmat, tensor2variable
from utils.smpl_util import SMPLX
from lib.dataset.mesh_util import apply_vertex_mask, part_removal, poisson, keep_largest
from scipy.spatial.transform import Rotation as R
from scipy.spatial import KDTree
import argparse

#### ------------------- config ----------------------
bg_color = np.array([1, 1, 1])
class colorModel(nn.Module):
    def __init__(self, renderer, v, f, c):
        super().__init__()
        self.renderer = renderer
        self.v = v
        self.f = f
        self.colors = nn.Parameter(c, requires_grad=True)
        self.bg_color = torch.from_numpy(bg_color).float().to(self.colors.device)

    def forward(self, return_mask=False):
        rgba = self.renderer.render(self.v, self.f, colors=self.colors)
        if return_mask:
            return rgba
        else:
            mask = rgba[..., 3:]
            return rgba[..., :3] * mask + self.bg_color * (1 - mask)
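# colorModel wraps fixed geometry (v, f) and exposes only the per-vertex
# colors as a learnable parameter, compositing renders over bg_color.
# A minimal usage sketch (hypothetical -- this script ultimately textures the
# mesh via multiview_color_projection instead):
#   cm = colorModel(renderer, vertices, faces, init_colors)
#   optim_c = torch.optim.Adam(cm.parameters(), lr=1e-2)
#   for _ in range(200):
#       optim_c.zero_grad()
#       loss = (cm() - target_colors).abs().mean()
#       loss.backward()
#       optim_c.step()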
def scale_mesh(vert):
    min_bbox, max_bbox = vert.min(0)[0], vert.max(0)[0]
    center = (min_bbox + max_bbox) / 2
    offset = -center
    vert = vert + offset
    max_dist = torch.max(torch.sqrt(torch.sum(vert**2, dim=1)))
    scale = 1.0 / max_dist
    return scale, offset
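# scale_mesh re-centers the bounding box at the origin and returns the scale
# that maps the farthest vertex to radius 1; optimize_case below applies
# (v + offset) * scale * 2, so the fitted body ends up inside a radius-2 sphere.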
def save_mesh(save_name, vertices, faces, color=None):
    vertex_colors = (color.detach().cpu().numpy() * 255).astype(np.uint8) if color is not None else None
    trimesh.Trimesh(
        vertices.detach().cpu().numpy(),
        faces.detach().cpu().numpy(),
        vertex_colors=vertex_colors,
    ).export(save_name)
class ReMesh:
    def __init__(self, opt, econ_dataset):
        self.opt = opt
        self.device = torch.device(f"cuda:{opt.gpu_id}" if torch.cuda.is_available() else "cpu")
        self.num_view = opt.num_view
        self.out_path = opt.res_path
        os.makedirs(self.out_path, exist_ok=True)
        self.resolution = opt.resolution
        self.views = ['front_face', 'front_right', 'right', 'back', 'left', 'front_left']
        self.weights = torch.Tensor([1., 0.4, 0.8, 1.0, 0.8, 0.4]).view(6, 1, 1, 1).to(self.device)
        self.renderer = self.prepare_render()
        # pose prediction
        self.econ_dataset = econ_dataset
        self.smplx_face = torch.Tensor(econ_dataset.faces.astype(np.int64)).long().to(self.device)

    def prepare_render(self):
        ### ------------------- prepare camera and renderer ----------------------
        mv, proj = make_sparse_camera(self.opt.cam_path, self.opt.scale, views=[0, 1, 2, 4, 6, 7], device=self.device)
        renderer = NormalsRenderer(mv, proj, [self.resolution, self.resolution], device=self.device)
        return renderer
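    # The six view indices [0, 1, 2, 4, 6, 7] presumably select azimuths
    # 0/45/90/180/270/315 degrees from an 8-camera (45-degree step) rig, which
    # matches both self.views and the angles passed to get_cameras_list below.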
    def proj_texture(self, fused_images, vertices, faces):
        mesh = to_py3d_mesh(vertices, faces)
        mesh = mesh.to(self.device)
        camera_focal = 1 / 2
        cameras_list = get_cameras_list([0, 45, 90, 180, 270, 315], device=self.device, focal=camera_focal)
        mesh = multiview_color_projection(mesh, fused_images, camera_focal=camera_focal, resolution=self.resolution,
                                          weights=self.weights.squeeze().cpu().numpy(), device=self.device,
                                          complete_unseen=True, confidence_threshold=0.2, cameras_list=cameras_list)
        return mesh
    def get_invisible_idx(self, imgs, vertices, faces):
        mesh = to_py3d_mesh(vertices, faces)
        mesh = mesh.to(self.device)
        camera_focal = 1 / 2
        if self.num_view == 6:
            cameras_list = get_cameras_list([0, 45, 90, 180, 270, 315], device=self.device, focal=camera_focal)
        elif self.num_view == 4:
            cameras_list = get_cameras_list([0, 90, 180, 270], device=self.device, focal=camera_focal)
        vertices_colors = torch.zeros((vertices.shape[0], 3)).float().to(self.device)
        valid_cnt = torch.zeros((vertices.shape[0])).to(self.device)
        for cam, img, weight in zip(cameras_list, imgs, self.weights.squeeze()):
            ret = project_color(mesh, cam, img, eps=0.01, resolution=self.resolution, device=self.device)
            valid_cnt[ret['valid_verts']] += weight
            vertices_colors[ret['valid_verts']] += ret['valid_colors'] * weight
        valid_mask = valid_cnt > 1
        vertices_colors[valid_mask] /= valid_cnt[valid_mask][:, None]
        # visibility: vertices whose accumulated view weight stays below 1 are treated as invisible
        invisible_vert = valid_cnt < 1
        invisible_vert_indices = torch.nonzero(invisible_vert).squeeze()
        return vertices_colors, invisible_vert_indices
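    # Each view contributes its per-view weight (front/back = 1.0, sides = 0.8,
    # diagonals = 0.4) to every vertex it sees, so a vertex with total weight
    # below 1 was only observed from unreliable angles (or not at all) and is
    # later filled in from its nearest colored neighbor.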
    def inpaint_missed_colors(self, all_vertices, all_colors, missing_indices):
        all_vertices = all_vertices.detach().cpu().numpy()
        all_colors = all_colors.detach().cpu().numpy()
        missing_indices = missing_indices.detach().cpu().numpy()
        non_missing_indices = np.setdiff1d(np.arange(len(all_vertices)), missing_indices)
        kdtree = KDTree(all_vertices[non_missing_indices])
        for missing_index in missing_indices:
            missing_vertex = all_vertices[missing_index]
            _, nearest_index = kdtree.query(missing_vertex.reshape(1, -1))
            interpolated_color = all_colors[non_missing_indices[nearest_index]]
            all_colors[missing_index] = interpolated_color
        return torch.from_numpy(all_colors).to(self.device)
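    # Nearest-neighbor fill: the KDTree is built over vertices that did receive
    # a projected color, and each missing vertex copies the color of its
    # closest colored neighbor in 3D. The per-vertex loop could also be done as
    # one batched query: _, nn = kdtree.query(all_vertices[missing_indices]).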
    def load_training_data(self, case):
        ###------------------ load target images -------------------------------
        kernel = torch.ones(3, 3)
        erode_iters = 2
        normals = []
        masks = []
        colors = []
        for idx, view in enumerate(self.views):
            normal = Image.open(f'{self.opt.mv_path}/{case}/normals_{view}_masked.png')
            normal = normal.convert('RGBA').resize((self.resolution, self.resolution), Image.BILINEAR)
            normal = np.array(normal).astype(np.float32) / 255.
            mask = normal[..., 3:]  # alpha
            # erode the alpha a few pixels to suppress boundary bleeding;
            # kornia's morphology expects a (B, C, H, W) tensor, hence the permute
            mask_torch = torch.from_numpy(mask).permute(2, 0, 1).unsqueeze(0)
            for _ in range(erode_iters):
                mask_torch = kornia.morphology.erosion(mask_torch, kernel)
            mask_erode = mask_torch.squeeze(0).permute(1, 2, 0).numpy()
            masks.append(mask_erode)
            normal = normal[..., :3] * mask_erode
            normals.append(normal)
            color = Image.open(f'{self.opt.mv_path}/{case}/color_{view}_masked.png')
            color = color.convert('RGBA').resize((self.resolution, self.resolution), Image.BILINEAR)
            color = np.array(color).astype(np.float32) / 255.
            color_dilate = color[..., :3] * mask_erode + bg_color * (1 - mask_erode)
            colors.append(color_dilate)
        masks = np.stack(masks, 0)
        masks = torch.from_numpy(masks).to(self.device)
        normals = np.stack(normals, 0)
        target_normals = torch.from_numpy(normals).to(self.device)
        colors = np.stack(colors, 0)
        target_colors = torch.from_numpy(colors).to(self.device)
        return masks, target_colors, target_normals
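    # preprocess mirrors load_training_data but takes the color/normal PIL
    # images directly (e.g. when an upstream multi-view generator hands them
    # over in memory) instead of reading the per-view files from mv_path.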
    def preprocess(self, color_pils, normal_pils):
        ###------------------ load target images -------------------------------
        kernel = torch.ones(3, 3)
        erode_iters = 2
        normals = []
        masks = []
        colors = []
        for normal, color in zip(normal_pils, color_pils):
            normal = normal.resize((self.resolution, self.resolution), Image.BILINEAR)
            normal = np.array(normal).astype(np.float32) / 255.
            mask = normal[..., 3:]  # alpha
            # same erosion as in load_training_data, in (B, C, H, W) layout
            mask_torch = torch.from_numpy(mask).permute(2, 0, 1).unsqueeze(0)
            for _ in range(erode_iters):
                mask_torch = kornia.morphology.erosion(mask_torch, kernel)
            mask_erode = mask_torch.squeeze(0).permute(1, 2, 0).numpy()
            masks.append(mask_erode)
            normal = normal[..., :3] * mask_erode
            normals.append(normal)
            color = color.resize((self.resolution, self.resolution), Image.BILINEAR)
            color = np.array(color).astype(np.float32) / 255.
            color_dilate = color[..., :3] * mask_erode + bg_color * (1 - mask_erode)
            colors.append(color_dilate)
        masks = np.stack(masks, 0)
        masks = torch.from_numpy(masks).to(self.device)
        normals = np.stack(normals, 0)
        target_normals = torch.from_numpy(normals).to(self.device)
        colors = np.stack(colors, 0)
        target_colors = torch.from_numpy(colors).to(self.device)
        return masks, target_colors, target_normals
    def optimize_case(self, case, pose, clr_img, nrm_img, opti_texture=True):
        case_path = f'{self.out_path}/{case}'
        os.makedirs(case_path, exist_ok=True)
        if clr_img is not None:
            masks, target_colors, target_normals = self.preprocess(clr_img, nrm_img)
        else:
            masks, target_colors, target_normals = self.load_training_data(case)
        # rotate the predicted body by 180 degrees around z, then y (coordinate convention alignment)
        rz = R.from_euler('z', 180, degrees=True).as_matrix()
        ry = R.from_euler('y', 180, degrees=True).as_matrix()
        rz = torch.from_numpy(rz).float().to(self.device)
        ry = torch.from_numpy(ry).float().to(self.device)
        scale, offset = None, None
        # pose parameters come from the HPS prediction in econ_dataset, stored as 6D rotations
        global_orient = pose["global_orient"]
        body_pose = pose["body_pose"]
        left_hand_pose = pose["left_hand_pose"]
        right_hand_pose = pose["right_hand_pose"]
        beta = pose["betas"]
        # The optimizer and variables
        optimed_pose = torch.tensor(body_pose,
                                    device=self.device,
                                    requires_grad=True)  # [1, 21, 6]
        optimed_trans = torch.tensor(pose["trans"],
                                     device=self.device,
                                     requires_grad=True)  # [3]
        optimed_betas = torch.tensor(beta,
                                     device=self.device,
                                     requires_grad=True)  # [1, 200]
        optimed_orient = torch.tensor(global_orient,
                                      device=self.device,
                                      requires_grad=True)  # [1, 1, 6]
        optimed_rhand = torch.tensor(right_hand_pose,
                                     device=self.device,
                                     requires_grad=True)
        optimed_lhand = torch.tensor(left_hand_pose,
                                     device=self.device,
                                     requires_grad=True)
        optimed_params = [
            {'params': [optimed_lhand, optimed_rhand], 'lr': 1e-3},
            {'params': [optimed_betas, optimed_trans, optimed_orient, optimed_pose], 'lr': 3e-3},
        ]
        optimizer_smpl = torch.optim.Adam(
            optimed_params,
            amsgrad=True,
        )
        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_smpl,
            mode="min",
            factor=0.5,
            verbose=0,
            min_lr=1e-5,
            patience=5,
        )
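        # Hands use a smaller learning rate (1e-3 vs 3e-3), presumably because
        # they are only weakly constrained by the normal/mask losses;
        # ReduceLROnPlateau halves all group rates whenever smpl_loss stalls
        # for more than 5 steps.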
        smpl_steps = 100
        for i in tqdm(range(smpl_steps)):
            optimizer_smpl.zero_grad()
            # 6D rotation -> rotation matrix
            optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, 6)).unsqueeze(0)
            optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, 6)).unsqueeze(0)
            smpl_verts, smpl_landmarks, smpl_joints = self.econ_dataset.smpl_model(
                shape_params=optimed_betas,
                expression_params=tensor2variable(pose["exp"], self.device),
                body_pose=optimed_pose_mat,
                global_pose=optimed_orient_mat,
                jaw_pose=tensor2variable(pose["jaw_pose"], self.device),
                left_hand_pose=optimed_lhand,
                right_hand_pose=optimed_rhand,
            )
            smpl_verts = smpl_verts + optimed_trans
            v_smpl = torch.matmul(torch.matmul(smpl_verts.squeeze(0), rz.T), ry.T)
            if scale is None:
                scale, offset = scale_mesh(v_smpl.detach())
            v_smpl = (v_smpl + offset) * scale * 2
            normals = calc_vertex_normals(v_smpl, self.smplx_face)
            nrm = self.renderer.render(v_smpl, self.smplx_face, normals=normals)
            masks_ = nrm[..., 3:]
            smpl_mask_loss = ((masks_ - masks) * self.weights).abs().mean()
            smpl_nrm_loss = ((nrm[..., :3] - target_normals) * self.weights).abs().mean()
            smpl_loss = smpl_mask_loss + smpl_nrm_loss
            smpl_loss.backward()
            optimizer_smpl.step()
            scheduler_smpl.step(smpl_loss)
        mesh_smpl = trimesh.Trimesh(vertices=v_smpl.detach().cpu().numpy(), faces=self.smplx_face.detach().cpu().numpy())
        nrm_opt = MeshOptimizer(v_smpl.detach(), self.smplx_face.detach(), edge_len_lims=[0.01, 0.1])
        vertices, faces = nrm_opt.vertices, nrm_opt.faces
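        # MeshOptimizer couples vertex-position optimization with adaptive
        # remeshing: after each gradient step, remesh() splits/collapses edges
        # to keep lengths within edge_len_lims (continuous-remeshing style),
        # letting the fitted SMPL-X mesh grow the detail needed to match the
        # target normal maps.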
        # ###----------------------- optimization iterations -------------------------------------
        for i in tqdm(range(self.opt.iters)):
            nrm_opt.zero_grad()
            normals = calc_vertex_normals(vertices, faces)
            nrm = self.renderer.render(vertices, faces, normals=normals)
            normals = nrm[..., :3]
            loss = ((normals - target_normals) * self.weights).abs().mean()
            alpha = nrm[..., 3:]
            loss += ((alpha - masks) * self.weights).abs().mean()
            loss.backward()
            nrm_opt.step()
            vertices, faces = nrm_opt.remesh()
            if self.opt.debug and i % self.opt.snapshot_step == 0:
                import imageio
                os.makedirs(f'{case_path}/normals', exist_ok=True)
                imageio.imwrite(f'{case_path}/normals/{i:02d}.png',
                                (nrm.detach()[0, :, :, :3] * 255).clamp(max=255).type(torch.uint8).cpu().numpy())
        torch.cuda.empty_cache()
        mesh_remeshed = trimesh.Trimesh(vertices=vertices.detach().cpu().numpy(), faces=faces.detach().cpu().numpy())
        mesh_remeshed.export(f'{case_path}/{case}_remeshed.obj')
        vertices = vertices.detach()
        faces = faces.detach()
        #### replace hand
        smpl_data = SMPLX()
        if self.opt.replace_hand and True in pose['hands_visibility'][0]:
            hand_mask = torch.zeros(smpl_data.smplx_verts.shape[0], )
            if pose['hands_visibility'][0][0]:
                hand_mask.index_fill_(
                    0, torch.tensor(smpl_data.smplx_mano_vid_dict["left_hand"]), 1.0
                )
            if pose['hands_visibility'][0][1]:
                hand_mask.index_fill_(
                    0, torch.tensor(smpl_data.smplx_mano_vid_dict["right_hand"]), 1.0
                )
            hand_mesh = apply_vertex_mask(mesh_smpl.copy(), hand_mask)
            body_mesh = part_removal(
                mesh_remeshed.copy(),
                hand_mesh,
                0.08,
                self.device,
                mesh_smpl.copy(),
                region="hand"
            )
            final = poisson(sum([hand_mesh, body_mesh]), f'{case_path}/{case}_final.obj', 10, False)
        else:
            final = poisson(mesh_remeshed, f'{case_path}/{case}_final.obj', 10, False)
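        # Hand replacement: the MANO-region vertices are cut out of the fitted
        # SMPL-X mesh (apply_vertex_mask), the matching region is removed from
        # the optimized scan (part_removal, threshold 0.08 in normalized
        # units), and Poisson reconstruction fuses the two parts into one
        # closed surface.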
        vertices = torch.from_numpy(final.vertices).float().to(self.device)
        faces = torch.from_numpy(final.faces).long().to(self.device)
        # Differing from the paper, we use the texturing method from Unique3D
        masked_color = []
        if clr_img is None:
            # fall back to the masked color images on disk (same files as in load_training_data)
            clr_img = [Image.open(f'{self.opt.mv_path}/{case}/color_{view}_masked.png') for view in self.views]
        for tmp in clr_img:
            tmp = tmp.resize((self.resolution, self.resolution), Image.BILINEAR)
            tmp = np.array(tmp).astype(np.float32) / 255.
            masked_color.append(torch.from_numpy(tmp).permute(2, 0, 1).to(self.device))
        meshes = self.proj_texture(masked_color, vertices, faces)
        vertices = meshes.verts_packed().float()
        faces = meshes.faces_packed().long()
        colors = meshes.textures.verts_features_packed().float()
        save_mesh(f'{case_path}/result_clr_scale{self.opt.scale}_{case}.obj', vertices, faces, colors)
        self.evaluate(vertices, colors, faces, save_path=f'{case_path}/result_clr_scale{self.opt.scale}_{case}.mp4', save_nrm=True)
        return vertices, faces, colors
    def evaluate(self, target_vertices, target_colors, target_faces, save_path=None, save_nrm=False):
        mv, proj = make_round_views(60, self.opt.scale, device=self.device)
        renderer = NormalsRenderer(mv, proj, [512, 512], device=self.device)
        target_images = renderer.render(target_vertices, target_faces, colors=target_colors)
        target_images = target_images.detach().cpu().numpy()
        target_images = target_images[..., :3] * target_images[..., 3:4] + bg_color * (1 - target_images[..., 3:4])
        target_images = (target_images.clip(0, 1) * 255).astype(np.uint8)
        if save_nrm:
            target_normals = calc_vertex_normals(target_vertices, target_faces)
            target_normals = renderer.render(target_vertices, target_faces, normals=target_normals)
            target_normals = target_normals.detach().cpu().numpy()
            target_normals = target_normals[..., :3] * target_normals[..., 3:4] + bg_color * (1 - target_normals[..., 3:4])
            target_normals = (target_normals.clip(0, 1) * 255).astype(np.uint8)
            frames = [np.concatenate([img, nrm], 1) for img, nrm in zip(target_images, target_normals)]
        else:
            frames = [img for img in target_images]
        if save_path is not None:
            write_video(frames, fps=25, save_path=save_path)
        return frames
    def run(self):
        cases = sorted(os.listdir(self.opt.imgs_path))
        for idx in range(len(cases)):
            case = cases[idx].split('.')[0]
            print(f'Processing {case}')
            pose = self.econ_dataset[idx]
            # optimize_case saves the textured mesh and renders the turntable video itself
            self.optimize_case(case, pose, None, None, opti_texture=True)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path to the yaml config file", default='config.yaml')
    args, extras = parser.parse_known_args()
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
    from econdataset import SMPLDataset
    dataset_param = {'image_dir': opt.imgs_path, 'seg_dir': None, 'colab': False, 'has_det': True, 'hps_type': 'pixie'}
    econdata = SMPLDataset(dataset_param, device='cuda')
    EHuman = ReMesh(opt, econdata)
    EHuman.run()
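# Example invocation (extra key=value args override config.yaml via OmegaConf's
# CLI syntax; the script name and paths below are illustrative):
#   python this_script.py --config config.yaml imgs_path=./examples res_path=./results gpu_id=0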