Spaces:
Runtime error
Runtime error
| from typing import Optional, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from jaxtyping import Float, Integer | |
| from torch import Tensor | |
| from .mesh import Mesh | |
class IsosurfaceHelper(nn.Module):
    """Base class for modules that extract a surface from values on a grid.

    Subclasses expose the grid sample positions through ``grid_vertices``.
    """

    # (low, high) bounds; subclasses use the span when scaling deformations.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> Float[Tensor, "N 3"]:
        """Return the grid sample positions; subclasses must override this.

        This is a property because ``forward`` implementations read it as a
        plain attribute (``self.grid_vertices``).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def requires_instance_per_batch(self) -> bool:
        """Return ``False`` in this base implementation."""
        # NOTE(review): kept as a regular method because nothing in this file
        # shows how it is consumed; confirm against callers whether it should
        # be a property like ``grid_vertices``.
        return False


class MarchingTetrahedraHelper(IsosurfaceHelper):
    """Marching-tetrahedra surface extraction on a tetrahedral grid from disk.

    The grid is read with ``np.load(tets_path)`` and must provide a
    ``"vertices"`` entry (vertex positions) and an ``"indices"`` entry (four
    vertex ids per tetrahedron). All lookup tables and grid tensors are
    registered as non-persistent buffers, so they follow the module across
    devices but are left out of ``state_dict``.
    """

    def __init__(self, resolution: int, tets_path: str):
        """Load the grid and build the marching-tetrahedra lookup tables.

        Args:
            resolution: Grid resolution; within this class it only scales
                the deformation in ``normalize_grid_deformation``.
            tets_path: Path passed to ``np.load``; the loaded object is
                indexed with ``"vertices"`` and ``"indices"``.
        """
        super().__init__()
        self.resolution = resolution
        self.tets_path = tets_path

        # Row ``code`` (the 4-bit occupancy code of a tetrahedron) holds
        # positions 0..5 into that tetrahedron's six edges, in
        # ``base_tet_edges`` order. Each consecutive triple is one triangle;
        # -1 pads unused slots. The table holds integer indices, hence the
        # ``Integer`` annotation.
        self.triangle_table: Integer[Tensor, "..."]
        self.register_buffer(
            "triangle_table",
            torch.as_tensor(
                [
                    [-1, -1, -1, -1, -1, -1],
                    [1, 0, 2, -1, -1, -1],
                    [4, 0, 3, -1, -1, -1],
                    [1, 4, 2, 1, 3, 4],
                    [3, 1, 5, -1, -1, -1],
                    [2, 3, 0, 2, 5, 3],
                    [1, 4, 0, 1, 5, 4],
                    [4, 2, 5, -1, -1, -1],
                    [4, 5, 2, -1, -1, -1],
                    [4, 1, 0, 4, 5, 1],
                    [3, 2, 0, 3, 5, 2],
                    [1, 3, 5, -1, -1, -1],
                    [4, 1, 2, 4, 3, 1],
                    [3, 0, 4, -1, -1, -1],
                    [2, 0, 1, -1, -1, -1],
                    [-1, -1, -1, -1, -1, -1],
                ],
                dtype=torch.long,
            ),
            persistent=False,
        )

        # Number of triangles emitted for each 4-bit occupancy code.
        self.num_triangles_table: Integer[Tensor, "..."]
        self.register_buffer(
            "num_triangles_table",
            torch.as_tensor(
                [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
            ),
            persistent=False,
        )

        # Flattened vertex-slot pairs for the six edges of one tetrahedron:
        # (0,1) (0,2) (0,3) (1,2) (1,3) (2,3).
        self.base_tet_edges: Integer[Tensor, "..."]
        self.register_buffer(
            "base_tet_edges",
            torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
            persistent=False,
        )

        tets = np.load(self.tets_path)
        self._grid_vertices: Float[Tensor, "..."]
        self.register_buffer(
            "_grid_vertices",
            torch.from_numpy(tets["vertices"]).float(),
            persistent=False,
        )
        self.indices: Integer[Tensor, "..."]
        self.register_buffer(
            "indices", torch.from_numpy(tets["indices"]).long(), persistent=False
        )

        # Lazily filled by the ``all_edges`` property.
        self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None

        center_indices, boundary_indices = self.get_center_boundary_index(
            self._grid_vertices
        )
        self.center_indices: Integer[Tensor, "..."]
        self.register_buffer("center_indices", center_indices, persistent=False)
        self.boundary_indices: Integer[Tensor, "..."]
        self.register_buffer("boundary_indices", boundary_indices, persistent=False)

    def get_center_boundary_index(self, verts):
        """Find the vertex nearest the origin and the vertices on the boundary.

        Args:
            verts: Vertex positions; the last dimension holds coordinates.

        Returns:
            A pair ``(center_idx, boundary_idx)``. ``center_idx`` is the
            ``argmin`` of the squared vertex norms. ``boundary_idx`` lists
            the vertices that have at least one coordinate equal to the
            global maximum or minimum coordinate value in ``verts``.
        """
        sq_norm = torch.sum(verts**2, dim=-1)
        center_idx = torch.argmin(sq_norm)
        at_max = verts == verts.max()
        at_min = verts == verts.min()
        on_boundary = (at_min | at_max).any(dim=-1)
        boundary_idx = torch.nonzero(on_boundary)
        return center_idx, boundary_idx.squeeze(dim=-1)

    def normalize_grid_deformation(
        self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
    ) -> Float[Tensor, "Nv 3"]:
        """Squash raw per-vertex offsets into a bounded displacement.

        The offsets go through ``tanh`` and are scaled by
        ``(points_range[1] - points_range[0]) / resolution``, so each
        component's magnitude stays below that scale.
        """
        return (
            (self.points_range[1] - self.points_range[0])
            / self.resolution  # half tet size is approximately 1 / self.resolution
            * torch.tanh(grid_vertex_offsets)
        )  # FIXME: hard-coded activation

    @property
    def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
        """Undeformed grid vertex positions loaded from ``tets_path``.

        Declared as a property because ``forward`` reads it as an attribute.
        """
        return self._grid_vertices

    @property
    def all_edges(self) -> Integer[Tensor, "Ne 2"]:
        """Unique grid edges as sorted vertex-id pairs, computed once.

        Declared as a property because ``forward`` hands ``self.all_edges``
        to ``Mesh`` as data. The result is cached on first access in a plain
        attribute (not a buffer).
        """
        # NOTE(review): the cache is not a registered buffer, so it is
        # presumably not moved by ``Module.to`` after the first access; confirm.
        if self._all_edges is None:
            # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
            edge_pairs = self.indices[:, self.base_tet_edges].reshape(-1, 2)
            edge_pairs = torch.sort(edge_pairs, dim=1)[0]
            self._all_edges = torch.unique(edge_pairs, dim=0)
        return self._all_edges

    def sort_edges(self, edges_ex2):
        """Order each edge so that the smaller vertex id comes first.

        Args:
            edges_ex2: Integer tensor of shape ``(E, 2)``.

        Returns:
            Tensor of shape ``(E, 1, 2)``: the two gathered columns keep
            their singleton dimension and are stacked on a new last axis.
        """
        with torch.no_grad():
            # 1 where the pair is stored in descending order, else 0.
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)
            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
        return torch.stack([a, b], -1)

    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        """Run marching tetrahedra and return ``(verts, faces)``.

        Args:
            pos_nx3: Vertex positions, indexed per vertex and reshaped to
                ``(-1, 2, 3)`` per edge.
            sdf_n: Per-vertex scalar values; vertices with a value ``> 0``
                are treated as occupied.
            tet_fx4: Vertex ids of each tetrahedron, four per row.

        Returns:
            ``verts``: one point per sign-changing edge, placed at the zero
            crossing of the linearly interpolated values.
            ``faces``: ``(T, 3)`` indices into ``verts``.
        """
        device = pos_nx3.device

        # The topology (which edges are crossed, how they are numbered) is
        # discrete, so it is computed without gradient tracking.
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # Only tetrahedra with mixed occupancy are crossed by the surface.
            valid_tets = (occ_sum > 0) & (occ_sum < 4)

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
            unique_edges = unique_edges.long()

            # An edge carries a surface vertex iff exactly one endpoint is occupied.
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = torch.full(
                (unique_edges.shape[0],), -1, dtype=torch.long, device=device
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=device
            )
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]

        # Interpolation stays outside ``no_grad`` so ``verts`` remains
        # connected to ``pos_nx3`` and ``sdf_n``.
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        # Negate the second endpoint so the sum below equals s0 - s1.
        edges_to_interp_sdf[:, -1] *= -1
        denominator = edges_to_interp_sdf.sum(1, keepdim=True)
        # Weights (-s1, s0) / (s0 - s1): each endpoint is weighted by the
        # other endpoint's magnitude.
        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        # Six edge slots per valid tetrahedron.
        idx_map = idx_map.reshape(-1, 6)

        # 4-bit occupancy code per valid tetrahedron (bit i set when corner i
        # is occupied).
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        one_tri = num_triangles == 1
        two_tri = num_triangles == 2
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[one_tri],
                    dim=1,
                    index=self.triangle_table[tetindex[one_tri]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[two_tri],
                    dim=1,
                    index=self.triangle_table[tetindex[two_tri]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )
        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Extract a mesh from per-vertex ``level`` values.

        Args:
            level: Scalar value per grid vertex; the surface is extracted
                where the sign changes.
            deformation: Optional raw per-vertex offsets. When given, they
                are passed through ``normalize_grid_deformation`` and added
                to the grid vertices before extraction.

        Returns:
            A ``Mesh`` built from the extracted vertices and faces. The
            (possibly deformed) grid vertices, the grid edges, ``level`` and
            ``deformation`` are forwarded as extra keyword arguments.
        """
        if deformation is not None:
            grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
                deformation
            )
        else:
            grid_vertices = self.grid_vertices

        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)

        mesh = Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            # extras
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )
        return mesh