Spaces:
Running
on
Zero
Running
on
Zero
| from dataclasses import dataclass, field | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import craftsman | |
| from .utils import ( | |
| Mesh, | |
| IsosurfaceHelper, | |
| MarchingCubeCPUHelper, | |
| MarchingTetrahedraHelper, | |
| ) | |
| from craftsman.utils.base import BaseModule | |
| from craftsman.utils.ops import chunk_batch, scale_tensor | |
| from craftsman.utils.typing import * | |
| class BaseGeometry(BaseModule): | |
| class Config(BaseModule.Config): | |
| pass | |
| cfg: Config | |
| def create_from( | |
| other: "BaseGeometry", cfg: Optional[Union[dict, DictConfig]] = None, **kwargs | |
| ) -> "BaseGeometry": | |
| raise TypeError( | |
| f"Cannot create {BaseGeometry.__name__} from {other.__class__.__name__}" | |
| ) | |
| def export(self, *args, **kwargs) -> Dict[str, Any]: | |
| return {} | |
| class BaseImplicitGeometry(BaseGeometry): | |
| class Config(BaseGeometry.Config): | |
| radius: float = 1.0 | |
| isosurface: bool = True | |
| isosurface_method: str = "mt" | |
| isosurface_resolution: int = 128 | |
| isosurface_threshold: Union[float, str] = 0.0 | |
| isosurface_chunk: int = 0 | |
| isosurface_coarse_to_fine: bool = True | |
| isosurface_deformable_grid: bool = False | |
| isosurface_remove_outliers: bool = True | |
| isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01 | |
| cfg: Config | |
| def configure(self) -> None: | |
| self.bbox: Float[Tensor, "2 3"] | |
| self.register_buffer( | |
| "bbox", | |
| torch.as_tensor( | |
| [ | |
| [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], | |
| [self.cfg.radius, self.cfg.radius, self.cfg.radius], | |
| ], | |
| dtype=torch.float32, | |
| ), | |
| ) | |
| self.isosurface_helper: Optional[IsosurfaceHelper] = None | |
| self.unbounded: bool = False | |
| def _initilize_isosurface_helper(self): | |
| if self.cfg.isosurface and self.isosurface_helper is None: | |
| if self.cfg.isosurface_method == "mc-cpu": | |
| self.isosurface_helper = MarchingCubeCPUHelper( | |
| self.cfg.isosurface_resolution | |
| ).to(self.device) | |
| elif self.cfg.isosurface_method == "mt": | |
| self.isosurface_helper = MarchingTetrahedraHelper( | |
| self.cfg.isosurface_resolution, | |
| f"load/tets/{self.cfg.isosurface_resolution}_tets.npz", | |
| ).to(self.device) | |
| else: | |
| raise AttributeError( | |
| "Unknown isosurface method {self.cfg.isosurface_method}" | |
| ) | |
| def forward( | |
| self, points: Float[Tensor, "*N Di"], output_normal: bool = False | |
| ) -> Dict[str, Float[Tensor, "..."]]: | |
| raise NotImplementedError | |
| def forward_field( | |
| self, points: Float[Tensor, "*N Di"] | |
| ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]: | |
| # return the value of the implicit field, could be density / signed distance | |
| # also return a deformation field if the grid vertices can be optimized | |
| raise NotImplementedError | |
| def forward_level( | |
| self, field: Float[Tensor, "*N 1"], threshold: float | |
| ) -> Float[Tensor, "*N 1"]: | |
| # return the value of the implicit field, where the zero level set represents the surface | |
| raise NotImplementedError | |
| def _isosurface(self, bbox: Float[Tensor, "2 3"], fine_stage: bool = False) -> Mesh: | |
| def batch_func(x): | |
| # scale to bbox as the input vertices are in [0, 1] | |
| field, deformation = self.forward_field( | |
| scale_tensor( | |
| x.to(bbox.device), self.isosurface_helper.points_range, bbox | |
| ), | |
| ) | |
| field = field.to( | |
| x.device | |
| ) # move to the same device as the input (could be CPU) | |
| if deformation is not None: | |
| deformation = deformation.to(x.device) | |
| return field, deformation | |
| assert self.isosurface_helper is not None | |
| field, deformation = chunk_batch( | |
| batch_func, | |
| self.cfg.isosurface_chunk, | |
| self.isosurface_helper.grid_vertices, | |
| ) | |
| threshold: float | |
| if isinstance(self.cfg.isosurface_threshold, float): | |
| threshold = self.cfg.isosurface_threshold | |
| elif self.cfg.isosurface_threshold == "auto": | |
| eps = 1.0e-5 | |
| threshold = field[field > eps].mean().item() | |
| craftsman.info( | |
| f"Automatically determined isosurface threshold: {threshold}" | |
| ) | |
| else: | |
| raise TypeError( | |
| f"Unknown isosurface_threshold {self.cfg.isosurface_threshold}" | |
| ) | |
| level = self.forward_level(field, threshold) | |
| mesh: Mesh = self.isosurface_helper(level, deformation=deformation) | |
| mesh.v_pos = scale_tensor( | |
| mesh.v_pos, self.isosurface_helper.points_range, bbox | |
| ) # scale to bbox as the grid vertices are in [0, 1] | |
| mesh.add_extra("bbox", bbox) | |
| if self.cfg.isosurface_remove_outliers: | |
| # remove outliers components with small number of faces | |
| # only enabled when the mesh is not differentiable | |
| mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold) | |
| return mesh | |
| def isosurface(self) -> Mesh: | |
| if not self.cfg.isosurface: | |
| raise NotImplementedError( | |
| "Isosurface is not enabled in the current configuration" | |
| ) | |
| self._initilize_isosurface_helper() | |
| if self.cfg.isosurface_coarse_to_fine: | |
| craftsman.debug("First run isosurface to get a tight bounding box ...") | |
| with torch.no_grad(): | |
| mesh_coarse = self._isosurface(self.bbox) | |
| vmin, vmax = mesh_coarse.v_pos.amin(dim=0), mesh_coarse.v_pos.amax(dim=0) | |
| vmin_ = (vmin - (vmax - vmin) * 0.1).max(self.bbox[0]) | |
| vmax_ = (vmax + (vmax - vmin) * 0.1).min(self.bbox[1]) | |
| craftsman.debug("Run isosurface again with the tight bounding box ...") | |
| mesh = self._isosurface(torch.stack([vmin_, vmax_], dim=0), fine_stage=True) | |
| else: | |
| mesh = self._isosurface(self.bbox) | |
| return mesh | |
| class BaseExplicitGeometry(BaseGeometry): | |
| class Config(BaseGeometry.Config): | |
| radius: float = 1.0 | |
| cfg: Config | |
| def configure(self) -> None: | |
| self.bbox: Float[Tensor, "2 3"] | |
| self.register_buffer( | |
| "bbox", | |
| torch.as_tensor( | |
| [ | |
| [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], | |
| [self.cfg.radius, self.cfg.radius, self.cfg.radius], | |
| ], | |
| dtype=torch.float32, | |
| ), | |
| ) |