Spaces:
Runtime error
Runtime error
| import os | |
| import math | |
| import cv2 | |
| import trimesh | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import mcubes | |
| import raymarching | |
| from .utils import custom_meshgrid, safe_normalize | |
def sample_pdf(bins, weights, n_samples, det=False):
    """Draw new samples by inverting the CDF of a piecewise-constant PDF.

    This implementation is from NeRF (hierarchical sampling).

    Args:
        bins: [B, T] tensor, old z values delimiting the bins.
        weights: [B, T - 1] tensor, bin weights.
        n_samples: number of samples to draw per row.
        det: if True use evenly spaced quantiles, otherwise uniform random ones.

    Returns:
        [B, n_samples] tensor of new z values.
    """
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Take uniform samples. They are cast to the dtype/device of the CDF because
    # torch.searchsorted rejects boundaries and values of different dtypes.
    sample_shape = list(cdf.shape[:-1]) + [n_samples]
    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples)
        u = u.to(device=cdf.device, dtype=cdf.dtype)
        u = u.expand(sample_shape)
    else:
        u = torch.rand(sample_shape).to(device=cdf.device, dtype=cdf.dtype)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    # clamp the neighbouring indices into the valid range [0, len(cdf) - 1]
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (B, n_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    # avoid dividing by (almost) zero for degenerate bins
    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
def plot_pointcloud(pc, color=None):
    """Print basic statistics of a point cloud and show it in a trimesh scene.

    pc: [N, 3]
    color: [N, 3/4]
    """
    print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))

    cloud = trimesh.PointCloud(pc, color)
    # reference geometry drawn next to the points: axes and a radius-1 icosphere
    reference = [
        trimesh.creation.axis(axis_length=4),
        trimesh.creation.icosphere(radius=1),
    ]
    trimesh.Scene([cloud, *reference]).show()
class NeRFRenderer(nn.Module):
    """Base volume renderer; subclasses supply the actual network.

    Subclasses must implement `forward`, `density` and `color`. The rendering
    code below additionally calls `self.background(...)` when `bg_radius > 0`;
    that method is not defined here and is presumably provided by the subclass.
    """

    def __init__(self, opt):
        """Read renderer settings from `opt` and allocate the state buffers.

        `opt` must expose `bound`, `cuda_ray`, `min_near`, `density_thresh`
        and `bg_radius`.
        """
        super().__init__()

        self.opt = opt
        self.bound = opt.bound
        self.cascade = 1 + math.ceil(math.log2(opt.bound))
        self.grid_size = 128
        self.cuda_ray = opt.cuda_ray
        self.min_near = opt.min_near
        self.density_thresh = opt.density_thresh
        self.bg_radius = opt.bg_radius

        # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
        # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
        aabb_train = torch.FloatTensor([-opt.bound, -opt.bound, -opt.bound, opt.bound, opt.bound, opt.bound])
        aabb_infer = aabb_train.clone()
        self.register_buffer('aabb_train', aabb_train)
        self.register_buffer('aabb_infer', aabb_infer)

        # extra state for cuda raymarching
        if self.cuda_ray:
            # density grid
            density_grid = torch.zeros([self.cascade, self.grid_size ** 3])  # [CAS, H * H * H]
            density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8)  # [CAS * H * H * H // 8]
            self.register_buffer('density_grid', density_grid)
            self.register_buffer('density_bitfield', density_bitfield)
            self.mean_density = 0
            self.iter_density = 0
            # step counter
            step_counter = torch.zeros(16, 2, dtype=torch.int32)  # 16 is hardcoded for averaging...
            self.register_buffer('step_counter', step_counter)
            self.mean_count = 0
            self.local_step = 0

    def forward(self, x, d):
        """Abstract; must be implemented by subclasses."""
        raise NotImplementedError()

    def density(self, x):
        """Abstract; must be implemented by subclasses.

        The renderer indexes the result with 'sigma' (and 'albedo' during
        texture export), so it is expected to be a dict of tensors.
        """
        raise NotImplementedError()

    def color(self, x, d, mask=None, **kwargs):
        """Abstract; must be implemented by subclasses."""
        raise NotImplementedError()

    def reset_extra_state(self):
        """Zero the density grid and step counter (no-op without cuda_ray)."""
        if not self.cuda_ray:
            return
        # density grid
        self.density_grid.zero_()
        self.mean_density = 0
        self.iter_density = 0
        # step counter
        self.step_counter.zero_()
        self.mean_count = 0
        self.local_step = 0

    def export_mesh(self, path, resolution=None, S=128):
        """Extract a mesh with marching cubes and write it as a textured OBJ.

        Writes `mesh.obj`, `mesh.mtl` and `albedo.png` into `path`.

        Args:
            path: output directory (created if missing).
            resolution: density grid resolution per axis; defaults to `self.grid_size`.
            S: chunk size per axis used when querying the density.
        """
        if resolution is None:
            resolution = self.grid_size

        # mean_density is only tracked in cuda_ray mode; fall back to the
        # configured threshold otherwise instead of raising AttributeError.
        if self.cuda_ray:
            density_thresh = min(self.mean_density, self.density_thresh)
        else:
            density_thresh = self.density_thresh

        # aabb_train is always registered, unlike density_bitfield.
        device = self.aabb_train.device

        if path:
            os.makedirs(path, exist_ok=True)

        sigmas = np.zeros([resolution, resolution, resolution], dtype=np.float32)

        # query
        X = torch.linspace(-1, 1, resolution).split(S)
        Y = torch.linspace(-1, 1, resolution).split(S)
        Z = torch.linspace(-1, 1, resolution).split(S)

        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [S, 3]
                    val = self.density(pts.to(device))
                    sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()  # [S, 1] --> [x, y, z]

        vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)

        # map grid indices back to [-1, 1]
        vertices = vertices / (resolution - 1.0) * 2 - 1
        vertices = vertices.astype(np.float32)
        triangles = triangles.astype(np.int32)

        v = torch.from_numpy(vertices).to(device)
        f = torch.from_numpy(triangles).int().to(device)

        # mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...
        # mesh.export(os.path.join(path, f'mesh.ply'))

        # texture?
        def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):
            """Unwrap UVs, bake the albedo into a texture and write obj/mtl/png.

            v, f: torch Tensor (vertices and triangle indices).
            h0, w0: output texture size; ssaa > 1 bakes at a higher resolution
            and downsamples afterwards.
            """
            device = v.device
            v_np = v.cpu().numpy()  # [N, 3]
            f_np = f.cpu().numpy()  # [M, 3]

            print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')

            # unwrap uvs
            import xatlas
            import nvdiffrast.torch as dr
            from sklearn.neighbors import NearestNeighbors
            from scipy.ndimage import binary_dilation, binary_erosion

            glctx = dr.RasterizeCudaContext()

            atlas = xatlas.Atlas()
            atlas.add_mesh(v_np, f_np)
            chart_options = xatlas.ChartOptions()
            chart_options.max_iterations = 0  # disable merge_chart for faster unwrap...
            atlas.generate(chart_options=chart_options)
            vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]

            # vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2]

            vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
            ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)

            # render uv maps
            uv = vt * 2.0 - 1.0  # uvs to range [-1, 1]
            uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1)  # [N, 4]

            if ssaa > 1:
                h = int(h0 * ssaa)
                w = int(w0 * ssaa)
            else:
                h, w = h0, w0

            rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w))  # [1, h, w, 4]
            xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f)  # [1, h, w, 3]
            mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f)  # [1, h, w, 1]

            # masked query
            xyzs = xyzs.view(-1, 3)
            mask = (mask > 0).view(-1)

            sigmas = torch.zeros(h * w, device=device, dtype=torch.float32)
            feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32)

            if mask.any():
                xyzs = xyzs[mask]  # [M, 3]

                # batched inference to avoid OOM
                all_sigmas = []
                all_feats = []

                head = 0
                while head < xyzs.shape[0]:
                    tail = min(head + 640000, xyzs.shape[0])
                    results_ = self.density(xyzs[head:tail])
                    # detach so the per-batch autograd graph is released and the
                    # baked texture can be converted to numpy below
                    all_sigmas.append(results_['sigma'].float().detach())
                    all_feats.append(results_['albedo'].float().detach())
                    head += 640000

                sigmas[mask] = torch.cat(all_sigmas, dim=0)
                feats[mask] = torch.cat(all_feats, dim=0)

            sigmas = sigmas.view(h, w, 1)
            feats = feats.view(h, w, -1)
            mask = mask.view(h, w)

            ### alpha mask
            # deltas = 2 * np.sqrt(3) / 1024
            # alphas = 1 - torch.exp(-sigmas * deltas)
            # alphas_mask = alphas > 0.5
            # feats = feats * alphas_mask

            # quantize [0.0, 1.0] to [0, 255]
            feats = feats.cpu().numpy()
            feats = (feats * 255).astype(np.uint8)

            # alphas = alphas.cpu().numpy()
            # alphas = (alphas * 255).astype(np.uint8)

            ### NN search as an antialiasing ...
            mask = mask.cpu().numpy()

            # texels to fill: a 3-pixel ring just outside the covered region
            inpaint_region = binary_dilation(mask, iterations=3)
            inpaint_region[mask] = 0

            # texels to copy from: the covered region minus its 2-pixel-eroded interior
            search_region = mask.copy()
            not_search_region = binary_erosion(search_region, iterations=2)
            search_region[not_search_region] = 0

            search_coords = np.stack(np.nonzero(search_region), axis=-1)
            inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)

            # nothing to fill (or nothing to copy from) when the mask is empty or full
            if len(search_coords) > 0 and len(inpaint_coords) > 0:
                knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
                _, indices = knn.kneighbors(inpaint_coords)

                feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]

            # do ssaa after the NN search, in numpy
            feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR)

            if ssaa > 1:
                # alphas = cv2.resize(alphas, (w0, h0), interpolation=cv2.INTER_NEAREST)
                feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR)

            # cv2.imwrite(os.path.join(path, f'alpha.png'), alphas)
            cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats)

            # save obj (v, vt, f /)
            obj_file = os.path.join(path, f'{name}mesh.obj')
            mtl_file = os.path.join(path, f'{name}mesh.mtl')

            print(f'[INFO] writing obj mesh to {obj_file}')
            with open(obj_file, "w") as fp:
                fp.write(f'mtllib {name}mesh.mtl \n')

                print(f'[INFO] writing vertices {v_np.shape}')
                for vert in v_np:
                    fp.write(f'v {vert[0]} {vert[1]} {vert[2]} \n')

                print(f'[INFO] writing vertices texture coords {vt_np.shape}')
                for tex in vt_np:
                    # the v coordinate is flipped when written
                    fp.write(f'vt {tex[0]} {1 - tex[1]} \n')

                print(f'[INFO] writing faces {f_np.shape}')
                fp.write('usemtl mat0 \n')
                for i in range(len(f_np)):
                    # OBJ indices are 1-based
                    fp.write(f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n")

            with open(mtl_file, "w") as fp:
                fp.write('newmtl mat0 \n')
                fp.write('Ka 1.000000 1.000000 1.000000 \n')
                fp.write('Kd 1.000000 1.000000 1.000000 \n')
                fp.write('Ks 0.000000 0.000000 0.000000 \n')
                fp.write('Tr 1.000000 \n')
                fp.write('illum 1 \n')
                fp.write('Ns 0.000000 \n')
                fp.write(f'map_Kd {name}albedo.png \n')

        _export(v, f)

    def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs):
        """Render rays with uniform sampling plus optional importance upsampling.

        Args:
            rays_o, rays_d: [B, N, 3], assumes B == 1.
            num_steps: number of uniform samples per ray.
            upsample_steps: number of extra samples drawn with `sample_pdf` (0 disables).
            light_d: light direction; sampled around `rays_o[0]` when None.
            ambient_ratio, shading: forwarded to `self(...)`.
            bg_color: [BN, 3] in range [0, 1]; ignored when `bg_radius > 0`,
                defaults to 1 when None.
            perturb: jitter the uniform samples.

        Returns:
            dict with 'image' [B, N, 3], 'depth' [B, N], 'weights_sum' [B, N],
            'mask' [B, N] and, when normals are available, 'loss_orient'.
        """
        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        results = {}

        # choose aabb
        aabb = self.aabb_train if self.training else self.aabb_infer

        # sample steps
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
        nears.unsqueeze_(-1)
        fars.unsqueeze_(-1)

        # random sample light_d if not provided
        if light_d is None:
            # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
            light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))
            light_d = safe_normalize(light_d)

        #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')

        z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0)  # [1, T]
        z_vals = z_vals.expand((N, num_steps))  # [N, T]
        z_vals = nears + (fars - nears) * z_vals  # [N, T], in [nears, fars]

        # perturb z_vals
        sample_dist = (fars - nears) / num_steps
        if perturb:
            z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
            #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.

        # generate xyzs
        xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
        xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:])  # a manual clip.

        #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())

        # query SDF and RGB
        density_outputs = self.density(xyzs.reshape(-1, 3))

        #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(N, num_steps, -1)

        # upsample z_vals (nerf-like)
        if upsample_steps > 0:
            with torch.no_grad():

                deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T-1]
                deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)

                alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1))  # [N, T]
                alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1)  # [N, T+1]
                weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T]

                # sample new z_vals
                z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1])  # [N, T-1]
                new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach()  # [N, t]

                new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
                new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:])  # a manual clip.

            # only forward new points to save computation
            new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
            #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
            for k, v in new_density_outputs.items():
                new_density_outputs[k] = v.view(N, upsample_steps, -1)

            # re-order
            z_vals = torch.cat([z_vals, new_z_vals], dim=1)  # [N, T+t]
            z_vals, z_index = torch.sort(z_vals, dim=1)

            xyzs = torch.cat([xyzs, new_xyzs], dim=1)  # [N, T+t, 3]
            xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))

            for k in density_outputs:
                tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
                density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))

        deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T+t-1]
        deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
        alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1))  # [N, T+t]
        alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1)  # [N, T+t+1]
        weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T+t]

        dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(-1, v.shape[-1])

        sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d, ratio=ambient_ratio, shading=shading)
        rgbs = rgbs.view(N, -1, 3)  # [N, T+t, 3]

        #print(xyzs.shape, 'valid_rgb:', mask.sum().item())

        # orientation loss
        if normals is not None:
            normals = normals.view(N, -1, 3)
            # print(weights.shape, normals.shape, dirs.shape)
            loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
            results['loss_orient'] = loss_orient.mean()

        # calculate weight_sum (mask)
        weights_sum = weights.sum(dim=-1)  # [N]

        # calculate depth
        ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
        depth = torch.sum(weights * ori_z_vals, dim=-1)

        # calculate color
        image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2)  # [N, 3], in [0, 1]

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            # sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
            bg_color = self.background(rays_d.reshape(-1, 3))  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color

        image = image.view(*prefix, 3)
        depth = depth.view(*prefix)
        # same layout as run_cuda, so callers see one shape for both code paths
        weights_sum = weights_sum.reshape(*prefix)

        mask = (nears < fars).reshape(*prefix)

        results['image'] = image
        results['depth'] = depth
        results['weights_sum'] = weights_sum
        results['mask'] = mask

        return results

    def run_cuda(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
        """Render rays with the `raymarching` extension and the density bitfield.

        Args:
            rays_o, rays_d: [B, N, 3], assumes B == 1.
            light_d: light direction; sampled around `rays_o[0]` when None.
            ambient_ratio, shading: forwarded to `self(...)`.
            bg_color: ignored when `bg_radius > 0`, defaults to 1 when None.
            dt_gamma, perturb, force_all_rays, max_steps, T_thresh: forwarded
                to the `raymarching` march/composite calls.

        Returns:
            dict with 'image' [B, N, 3], 'depth' [B, N], 'weights_sum' [B, N],
            'mask' [B, N] and, in training mode with normals, 'loss_orient'.
        """
        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        # pre-calculate near far
        # NOTE(review): unlike run(), self.min_near is not passed here, so the
        # extension's default is used - confirm whether that is intended.
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer)

        # random sample light_d if not provided
        if light_d is None:
            # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
            light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))
            light_d = safe_normalize(light_d)

        results = {}

        if self.training:
            # setup counter
            counter = self.step_counter[self.local_step % 16]
            counter.zero_()  # set to 0
            self.local_step += 1

            xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)

            #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())

            sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)

            #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})')

            weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)

            # orientation loss
            if normals is not None:
                weights = 1 - torch.exp(-sigmas)
                loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
                results['loss_orient'] = loss_orient.mean()

        else:

            # allocate outputs
            dtype = torch.float32

            weights_sum = torch.zeros(N, dtype=dtype, device=device)
            depth = torch.zeros(N, dtype=dtype, device=device)
            image = torch.zeros(N, 3, dtype=dtype, device=device)

            n_alive = N
            rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device)  # [N]
            rays_t = nears.clone()  # [N]

            step = 0

            while step < max_steps:  # hard coded max step

                # count alive rays
                n_alive = rays_alive.shape[0]

                # exit loop
                if n_alive <= 0:
                    break

                # decide compact_steps
                n_step = max(min(N // n_alive, 8), 1)

                xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)

                sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)

                raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)

                # keep only rays whose index is still non-negative
                rays_alive = rays_alive[rays_alive >= 0]

                #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')

                step += n_step

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            # sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
            bg_color = self.background(rays_d)  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
        image = image.view(*prefix, 3)

        depth = torch.clamp(depth - nears, min=0) / (fars - nears)
        depth = depth.view(*prefix)

        weights_sum = weights_sum.reshape(*prefix)

        mask = (nears < fars).reshape(*prefix)

        results['image'] = image
        results['depth'] = depth
        results['weights_sum'] = weights_sum
        results['mask'] = mask

        return results

    def update_extra_state(self, decay=0.95, S=128):
        """Refresh the density grid, its bitfield and the mean step count.

        call before each epoch to update extra states. No-op without cuda_ray.

        Args:
            decay: factor applied to the previous grid values before taking
                the maximum with the freshly queried densities.
            S: chunk size per axis when iterating over grid cells.
        """
        if not self.cuda_ray:
            return

        ### update density grid
        tmp_grid = - torch.ones_like(self.density_grid)

        X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)

        for xs in X:
            for ys in Y:
                for zs in Z:

                    # construct points
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [N, 3], in [0, 128)
                    indices = raymarching.morton3D(coords).long()  # [N]
                    xyzs = 2 * coords.float() / (self.grid_size - 1) - 1  # [N, 3] in [-1, 1]

                    # cascading
                    for cas in range(self.cascade):
                        bound = min(2 ** cas, self.bound)
                        half_grid_size = bound / self.grid_size
                        # scale to current cascade's resolution
                        cas_xyzs = xyzs * (bound - half_grid_size)
                        # add noise in [-hgs, hgs]
                        cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                        # query density
                        sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
                        # assign
                        tmp_grid[cas, indices] = sigmas

        # ema update
        valid_mask = self.density_grid >= 0
        self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
        self.mean_density = torch.mean(self.density_grid[valid_mask]).item()
        self.iter_density += 1

        # convert to bitfield
        density_thresh = min(self.mean_density, self.density_thresh)
        self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)

        ### update step counter
        total_step = min(16, self.local_step)
        if total_step > 0:
            self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
        self.local_step = 0

        # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > density_thresh).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')

    def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
        """Render rays, optionally in chunks to bound memory use.

        Args:
            rays_o, rays_d: [B, N, 3], assumes B == 1.
            staged: process `max_ray_batch` rays at a time (ignored with cuda_ray).
            max_ray_batch: chunk size used when `staged` is True.
            **kwargs: forwarded to `run` / `run_cuda`.

        Returns:
            The result dict of `run` / `run_cuda`. In staged mode only
            'depth', 'image', 'weights_sum' and (when provided by the chunks)
            'mask' are stitched together.
        """
        if self.cuda_ray:
            _run = self.run_cuda
        else:
            _run = self.run

        B, N = rays_o.shape[:2]
        device = rays_o.device

        # never stage when cuda_ray
        if staged and not self.cuda_ray:
            depth = torch.empty((B, N), device=device)
            image = torch.empty((B, N, 3), device=device)
            weights_sum = torch.empty((B, N), device=device)
            mask = None  # allocated lazily, only if the chunks report one

            for b in range(B):
                head = 0
                while head < N:
                    tail = min(head + max_ray_batch, N)
                    results_ = _run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs)
                    depth[b:b+1, head:tail] = results_['depth']
                    weights_sum[b:b+1, head:tail] = results_['weights_sum']
                    image[b:b+1, head:tail] = results_['image']
                    if 'mask' in results_:
                        if mask is None:
                            mask = torch.zeros((B, N), dtype=torch.bool, device=device)
                        mask[b:b+1, head:tail] = results_['mask']
                    head += max_ray_batch

            results = {}
            results['depth'] = depth
            results['image'] = image
            results['weights_sum'] = weights_sum
            if mask is not None:
                results['mask'] = mask

        else:
            results = _run(rays_o, rays_d, **kwargs)

        return results