Spaces:
Running
on
Zero
Running
on
Zero
| import time | |
| import torch | |
| import trimesh | |
| import numpy as np | |
| import torch.optim as optim | |
| from torch import autograd | |
| from torch.utils.data import TensorDataset, DataLoader | |
| from .common import make_3d_grid | |
| from .utils import libmcubes | |
| from .utils.libmise import MISE | |
| from .utils.libsimplify import simplify_mesh | |
| from .common import transform_pointcloud | |
| class Generator3D(object): | |
| ''' Generator class for DVRs. | |
| It provides functions to generate the final mesh as well refining options. | |
| Args: | |
| model (nn.Module): trained DVR model | |
| points_batch_size (int): batch size for points evaluation | |
| threshold (float): threshold value | |
| refinement_step (int): number of refinement steps | |
| device (device): pytorch device | |
| resolution0 (int): start resolution for MISE | |
| upsampling steps (int): number of upsampling steps | |
| with_normals (bool): whether normals should be estimated | |
| padding (float): how much padding should be used for MISE | |
| simplify_nfaces (int): number of faces the mesh should be simplified to | |
| refine_max_faces (int): max number of faces which are used as batch | |
| size for refinement process (we added this functionality in this | |
| work) | |
| ''' | |
| def __init__( | |
| self, | |
| model, | |
| points_batch_size=100000, | |
| threshold=0.5, | |
| refinement_step=0, | |
| device=None, | |
| resolution0=16, | |
| upsampling_steps=3, | |
| with_normals=False, | |
| padding=0.1, | |
| simplify_nfaces=None, | |
| with_color=False, | |
| refine_max_faces=10000 | |
| ): | |
| self.model = model.to(device) | |
| self.points_batch_size = points_batch_size | |
| self.refinement_step = refinement_step | |
| self.threshold = threshold | |
| self.device = device | |
| self.resolution0 = resolution0 | |
| self.upsampling_steps = upsampling_steps | |
| self.with_normals = with_normals | |
| self.padding = padding | |
| self.simplify_nfaces = simplify_nfaces | |
| self.with_color = with_color | |
| self.refine_max_faces = refine_max_faces | |
| def generate_mesh(self, data, return_stats=True): | |
| ''' Generates the output mesh. | |
| Args: | |
| data (tensor): data tensor | |
| return_stats (bool): whether stats should be returned | |
| ''' | |
| self.model.eval() | |
| device = self.device | |
| stats_dict = {} | |
| inputs = data.get('inputs', torch.empty(1, 0)).to(device) | |
| kwargs = {} | |
| c = self.model.encode_inputs(inputs) | |
| mesh = self.generate_from_latent(c, stats_dict=stats_dict, data=data, **kwargs) | |
| return mesh, stats_dict | |
| def generate_meshes(self, data, return_stats=True): | |
| ''' Generates the output meshes with data of batch size >=1 | |
| Args: | |
| data (tensor): data tensor | |
| return_stats (bool): whether stats should be returned | |
| ''' | |
| self.model.eval() | |
| device = self.device | |
| stats_dict = {} | |
| inputs = data.get('inputs', torch.empty(1, 1, 0)).to(device) | |
| meshes = [] | |
| for i in range(inputs.shape[0]): | |
| input_i = inputs[i].unsqueeze(0) | |
| c = self.model.encode_inputs(input_i) | |
| mesh = self.generate_from_latent(c, stats_dict=stats_dict) | |
| meshes.append(mesh) | |
| return meshes | |
| def generate_pointcloud(self, mesh, data=None, n_points=2000000, scale_back=True): | |
| ''' Generates a point cloud from the mesh. | |
| Args: | |
| mesh (trimesh): mesh | |
| data (dict): data dictionary | |
| n_points (int): number of point cloud points | |
| scale_back (bool): whether to undo scaling (requires a scale | |
| matrix in data dictionary) | |
| ''' | |
| pcl = mesh.sample(n_points).astype(np.float32) | |
| if scale_back: | |
| scale_mat = data.get('camera.scale_mat_0', None) | |
| if scale_mat is not None: | |
| pcl = transform_pointcloud(pcl, scale_mat[0]) | |
| else: | |
| print('Warning: No scale_mat found!') | |
| pcl_out = trimesh.Trimesh(vertices=pcl, process=False) | |
| return pcl_out | |
| def generate_from_latent(self, c=None, pl=None, stats_dict={}, data=None, **kwargs): | |
| ''' Generates mesh from latent. | |
| Args: | |
| c (tensor): latent conditioned code c | |
| pl (tensor): predicted plane parameters | |
| stats_dict (dict): stats dictionary | |
| ''' | |
| threshold = np.log(self.threshold) - np.log(1. - self.threshold) | |
| t0 = time.time() | |
| # Compute bounding box size | |
| box_size = 1 + self.padding | |
| # Shortcut | |
| if self.upsampling_steps == 0: | |
| nx = self.resolution0 | |
| pointsf = box_size * make_3d_grid((-0.5, ) * 3, (0.5, ) * 3, (nx, ) * 3) | |
| values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy() | |
| value_grid = values.reshape(nx, nx, nx) | |
| else: | |
| mesh_extractor = MISE(self.resolution0, self.upsampling_steps, threshold) | |
| points = mesh_extractor.query() | |
| while points.shape[0] != 0: | |
| # Query points | |
| pointsf = torch.FloatTensor(points).to(self.device) | |
| # Normalize to bounding box | |
| pointsf = 2 * pointsf / mesh_extractor.resolution | |
| pointsf = box_size * (pointsf - 1.0) | |
| # Evaluate model and update | |
| values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy() | |
| values = values.astype(np.float64) | |
| mesh_extractor.update(points, values) | |
| points = mesh_extractor.query() | |
| value_grid = mesh_extractor.to_dense() | |
| # Extract mesh | |
| stats_dict['time (eval points)'] = time.time() - t0 | |
| mesh = self.extract_mesh(value_grid, c, stats_dict=stats_dict) | |
| return mesh | |
| def eval_points(self, p, c=None, pl=None, **kwargs): | |
| ''' Evaluates the occupancy values for the points. | |
| Args: | |
| p (tensor): points | |
| c (tensor): latent conditioned code c | |
| ''' | |
| p_split = torch.split(p, self.points_batch_size) | |
| occ_hats = [] | |
| for pi in p_split: | |
| pi = pi.unsqueeze(0).to(self.device) | |
| with torch.no_grad(): | |
| occ_hat = self.model.decode(pi, c, pl, **kwargs).logits | |
| occ_hats.append(occ_hat.squeeze(0).detach().cpu()) | |
| occ_hat = torch.cat(occ_hats, dim=0) | |
| return occ_hat | |
| def extract_mesh(self, occ_hat, c=None, stats_dict=dict()): | |
| ''' Extracts the mesh from the predicted occupancy grid. | |
| Args: | |
| occ_hat (tensor): value grid of occupancies | |
| c (tensor): latent conditioned code c | |
| stats_dict (dict): stats dictionary | |
| ''' | |
| # Some short hands | |
| n_x, n_y, n_z = occ_hat.shape | |
| box_size = 1 + self.padding | |
| threshold = np.log(self.threshold) - np.log(1. - self.threshold) | |
| # Make sure that mesh is watertight | |
| t0 = time.time() | |
| occ_hat_padded = np.pad(occ_hat, 1, 'constant', constant_values=-1e6) | |
| vertices, triangles = libmcubes.marching_cubes(occ_hat_padded, threshold) | |
| stats_dict['time (marching cubes)'] = time.time() - t0 | |
| # Strange behaviour in libmcubes: vertices are shifted by 0.5 | |
| vertices -= 0.5 | |
| # Undo padding | |
| vertices -= 1 | |
| # Normalize to bounding box | |
| vertices /= np.array([n_x - 1, n_y - 1, n_z - 1]) | |
| vertices *= 2 | |
| vertices = box_size * (vertices - 1) | |
| # mesh_pymesh = pymesh.form_mesh(vertices, triangles) | |
| # mesh_pymesh = fix_pymesh(mesh_pymesh) | |
| # Estimate normals if needed | |
| if self.with_normals and not vertices.shape[0] == 0: | |
| t0 = time.time() | |
| normals = self.estimate_normals(vertices, c) | |
| stats_dict['time (normals)'] = time.time() - t0 | |
| else: | |
| normals = None | |
| # Create mesh | |
| mesh = trimesh.Trimesh( | |
| vertices, | |
| triangles, | |
| vertex_normals=normals, | |
| # vertex_colors=vertex_colors, | |
| process=False | |
| ) | |
| # Directly return if mesh is empty | |
| if vertices.shape[0] == 0: | |
| return mesh | |
| # TODO: normals are lost here | |
| if self.simplify_nfaces is not None: | |
| t0 = time.time() | |
| mesh = simplify_mesh(mesh, self.simplify_nfaces, 5.) | |
| stats_dict['time (simplify)'] = time.time() - t0 | |
| # Refine mesh | |
| if self.refinement_step > 0: | |
| t0 = time.time() | |
| self.refine_mesh(mesh, occ_hat, c) | |
| stats_dict['time (refine)'] = time.time() - t0 | |
| # Estimate Vertex Colors | |
| if self.with_color and not vertices.shape[0] == 0: | |
| t0 = time.time() | |
| vertex_colors = self.estimate_colors(np.array(mesh.vertices), c) | |
| stats_dict['time (color)'] = time.time() - t0 | |
| mesh = trimesh.Trimesh( | |
| vertices=mesh.vertices, | |
| faces=mesh.faces, | |
| vertex_normals=mesh.vertex_normals, | |
| vertex_colors=vertex_colors, | |
| process=False | |
| ) | |
| return mesh | |
| def estimate_colors(self, vertices, c=None): | |
| ''' Estimates vertex colors by evaluating the texture field. | |
| Args: | |
| vertices (numpy array): vertices of the mesh | |
| c (tensor): latent conditioned code c | |
| ''' | |
| device = self.device | |
| vertices = torch.FloatTensor(vertices) | |
| vertices_split = torch.split(vertices, self.points_batch_size) | |
| colors = [] | |
| for vi in vertices_split: | |
| vi = vi.to(device) | |
| with torch.no_grad(): | |
| ci = self.model.decode_color(vi.unsqueeze(0), c).squeeze(0).cpu() | |
| colors.append(ci) | |
| colors = np.concatenate(colors, axis=0) | |
| colors = np.clip(colors, 0, 1) | |
| colors = (colors * 255).astype(np.uint8) | |
| colors = np.concatenate( | |
| [colors, np.full((colors.shape[0], 1), 255, dtype=np.uint8)], axis=1 | |
| ) | |
| return colors | |
| def estimate_normals(self, vertices, c=None): | |
| ''' Estimates the normals by computing the gradient of the objective. | |
| Args: | |
| vertices (numpy array): vertices of the mesh | |
| z (tensor): latent code z | |
| c (tensor): latent conditioned code c | |
| ''' | |
| device = self.device | |
| vertices = torch.FloatTensor(vertices) | |
| vertices_split = torch.split(vertices, self.points_batch_size) | |
| normals = [] | |
| c = c.unsqueeze(0) | |
| for vi in vertices_split: | |
| vi = vi.unsqueeze(0).to(device) | |
| vi.requires_grad_() | |
| occ_hat = self.model.decode(vi, c).logits | |
| out = occ_hat.sum() | |
| out.backward() | |
| ni = -vi.grad | |
| ni = ni / torch.norm(ni, dim=-1, keepdim=True) | |
| ni = ni.squeeze(0).cpu().numpy() | |
| normals.append(ni) | |
| normals = np.concatenate(normals, axis=0) | |
| return normals | |
| def refine_mesh(self, mesh, occ_hat, c=None): | |
| ''' Refines the predicted mesh. | |
| Args: | |
| mesh (trimesh object): predicted mesh | |
| occ_hat (tensor): predicted occupancy grid | |
| c (tensor): latent conditioned code c | |
| ''' | |
| self.model.eval() | |
| # Some shorthands | |
| n_x, n_y, n_z = occ_hat.shape | |
| assert (n_x == n_y == n_z) | |
| # threshold = np.log(self.threshold) - np.log(1. - self.threshold) | |
| threshold = self.threshold | |
| # Vertex parameter | |
| v0 = torch.FloatTensor(mesh.vertices).to(self.device) | |
| v = torch.nn.Parameter(v0.clone()) | |
| # Faces of mesh | |
| faces = torch.LongTensor(mesh.faces) | |
| # detach c; otherwise graph needs to be retained | |
| # caused by new Pytorch version? | |
| c = c.detach() | |
| # Start optimization | |
| optimizer = optim.RMSprop([v], lr=1e-5) | |
| # Dataset | |
| ds_faces = TensorDataset(faces) | |
| dataloader = DataLoader(ds_faces, batch_size=self.refine_max_faces, shuffle=True) | |
| # We updated the refinement algorithm to subsample faces; this is | |
| # usefull when using a high extraction resolution / when working on | |
| # small GPUs | |
| it_r = 0 | |
| while it_r < self.refinement_step: | |
| for f_it in dataloader: | |
| f_it = f_it[0].to(self.device) | |
| optimizer.zero_grad() | |
| # Loss | |
| face_vertex = v[f_it] | |
| eps = np.random.dirichlet((0.5, 0.5, 0.5), size=f_it.shape[0]) | |
| eps = torch.FloatTensor(eps).to(self.device) | |
| face_point = (face_vertex * eps[:, :, None]).sum(dim=1) | |
| face_v1 = face_vertex[:, 1, :] - face_vertex[:, 0, :] | |
| face_v2 = face_vertex[:, 2, :] - face_vertex[:, 1, :] | |
| face_normal = torch.cross(face_v1, face_v2) | |
| face_normal = face_normal / \ | |
| (face_normal.norm(dim=1, keepdim=True) + 1e-10) | |
| face_value = torch.cat( | |
| [ | |
| torch.sigmoid(self.model.decode(p_split, c).logits) | |
| for p_split in torch.split(face_point.unsqueeze(0), 20000, dim=1) | |
| ], | |
| dim=1 | |
| ) | |
| normal_target = -autograd.grad([face_value.sum()], [face_point], | |
| create_graph=True)[0] | |
| normal_target = \ | |
| normal_target / \ | |
| (normal_target.norm(dim=1, keepdim=True) + 1e-10) | |
| loss_target = (face_value - threshold).pow(2).mean() | |
| loss_normal = \ | |
| (face_normal - normal_target).pow(2).sum(dim=1).mean() | |
| loss = loss_target + 0.01 * loss_normal | |
| # Update | |
| loss.backward() | |
| optimizer.step() | |
| # Update it_r | |
| it_r += 1 | |
| if it_r >= self.refinement_step: | |
| break | |
| mesh.vertices = v.data.cpu().numpy() | |
| return mesh | |