import os
import cv2
import glob
import json
import tqdm
import random
import numpy as np

from scipy.spatial.transform import Slerp, Rotation

import trimesh
import torch
from torch.utils.data import DataLoader

from .utils import get_rays, safe_normalize

def visualize_poses(poses, size=0.1):
    # poses: [B, 4, 4]

    axes = trimesh.creation.axis(axis_length=4)
    sphere = trimesh.creation.icosphere(radius=1)
    objects = [axes, sphere]

    for pose in poses:
        # a camera is visualized with 8 line segments.
        pos = pose[:3, 3]
        a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
        b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
        c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
        d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]

        segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a]])
        segs = trimesh.load_path(segs)
        objects.append(segs)

    trimesh.Scene(objects).show()

def get_view_direction(thetas, phis, overhead, front):
    # phis [B,]; thetas: [B,] -- all angles are in radians here
    # front = 0         phis in [0, front)
    # side (left) = 1   phis in [front, pi)
    # back = 2          phis in [pi, pi + front)
    # side (right) = 3  phis in [pi + front, 2 * pi)
    # top = 4           thetas in [0, overhead]
    # bottom = 5        thetas in [pi - overhead, pi]
    res = torch.zeros(thetas.shape[0], dtype=torch.long)

    # first determine by phis
    res[(phis < front)] = 0
    res[(phis >= front) & (phis < np.pi)] = 1
    res[(phis >= np.pi) & (phis < (np.pi + front))] = 2
    res[(phis >= (np.pi + front))] = 3

    # override by thetas
    res[thetas <= overhead] = 4
    res[thetas >= (np.pi - overhead)] = 5
    return res
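
# Illustrative sanity check (a sketch, not part of the original module): bin a
# batch of hand-picked angles. Inputs are in radians, matching the deg2rad
# conversion done by the callers below.
#
#   thetas = torch.deg2rad(torch.tensor([10.0, 90.0, 170.0]))
#   phis = torch.deg2rad(torch.tensor([30.0, 200.0, 350.0]))
#   get_view_direction(thetas, phis, np.deg2rad(30), np.deg2rad(60))
#   # -> tensor([4, 2, 5]): top (theta <= 30 deg), back (180 <= phi < 240), bottom (theta >= 150)
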
def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 100], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, jitter=False):
    ''' generate random poses from an orbit camera
    Args:
        size: batch size of generated poses.
        device: where to allocate the output.
        radius_range: [min, max] camera radius.
        theta_range: [min, max] polar angle in degrees, should be in [0, 180].
        phi_range: [min, max] azimuth angle in degrees, should be in [0, 360].
    Return:
        poses: [size, 4, 4]
    '''
    theta_range = np.deg2rad(theta_range)
    phi_range = np.deg2rad(phi_range)
    angle_overhead = np.deg2rad(angle_overhead)
    angle_front = np.deg2rad(angle_front)

    radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
    thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
    phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]

    centers = torch.stack([
        radius * torch.sin(thetas) * torch.sin(phis),
        radius * torch.cos(thetas),
        radius * torch.sin(thetas) * torch.cos(phis),
    ], dim=-1)  # [B, 3]

    targets = 0  # look at the origin by default

    # jitters
    if jitter:
        centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
        targets = targets + torch.randn_like(centers) * 0.2

    # lookat
    forward_vector = safe_normalize(targets - centers)
    up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))

    if jitter:
        up_noise = torch.randn_like(up_vector) * 0.02
    else:
        up_noise = 0

    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    if return_dirs:
        dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)
    else:
        dirs = None

    return poses, dirs
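
# Illustrative usage (a sketch, not part of the original module): sample a
# small batch on CPU and check the returned shapes.
#
#   poses, dirs = rand_poses(4, torch.device('cpu'), return_dirs=True)
#   poses.shape  # torch.Size([4, 4, 4]) -- camera-to-world matrices
#   dirs.shape   # torch.Size([4])       -- view-direction labels in {0..5}
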
def circle_poses(device, radius=1.25, theta=60, phi=0, return_dirs=False, angle_overhead=30, angle_front=60):
    # theta/phi are in degrees; returns a single pose [1, 4, 4] looking at the origin.

    theta = np.deg2rad(theta)
    phi = np.deg2rad(phi)
    angle_overhead = np.deg2rad(angle_overhead)
    angle_front = np.deg2rad(angle_front)

    thetas = torch.FloatTensor([theta]).to(device)
    phis = torch.FloatTensor([phi]).to(device)

    centers = torch.stack([
        radius * torch.sin(thetas) * torch.sin(phis),
        radius * torch.cos(thetas),
        radius * torch.sin(thetas) * torch.cos(phis),
    ], dim=-1)  # [B, 3]

    # lookat
    forward_vector = - safe_normalize(centers)
    up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0)
    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    if return_dirs:
        dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)
    else:
        dirs = None

    return poses, dirs
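
# Illustrative usage (a sketch): sweep phi to build a turntable of evaluation
# poses around the origin, mirroring what `collate` does in its test branch.
#
#   for phi in np.linspace(0, 360, 8, endpoint=False):
#       pose, _ = circle_poses(torch.device('cpu'), radius=1.5, theta=60, phi=float(phi))
#       # pose: [1, 4, 4]
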
class NeRFDataset:
    def __init__(self, opt, device, type='train', H=256, W=256, size=100):
        super().__init__()

        self.opt = opt
        self.device = device
        self.type = type  # train, val, test

        self.H = H
        self.W = W
        self.radius_range = opt.radius_range
        self.fovy_range = opt.fovy_range
        self.size = size

        self.training = self.type in ['train', 'all']

        # principal point at the image center (cx runs along W, cy along H)
        self.cx = self.W / 2
        self.cy = self.H / 2

        # [debug] visualize poses
        # poses, dirs = rand_poses(100, self.device, return_dirs=self.opt.dir_text, radius_range=self.radius_range)
        # visualize_poses(poses.detach().cpu().numpy())
    def collate(self, index):

        B = len(index)  # always 1

        if self.training:
            # random pose on the fly
            poses, dirs = rand_poses(B, self.device, radius_range=self.radius_range, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose)

            # random focal
            fov = random.random() * (self.fovy_range[1] - self.fovy_range[0]) + self.fovy_range[0]
            focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2))
            intrinsics = np.array([focal, focal, self.cx, self.cy])
        else:
            # circle pose
            phi = (index[0] / self.size) * 360
            poses, dirs = circle_poses(self.device, radius=self.radius_range[1] * 1.2, theta=60, phi=phi, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)

            # fixed focal
            fov = (self.fovy_range[1] + self.fovy_range[0]) / 2
            focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2))
            intrinsics = np.array([focal, focal, self.cx, self.cy])

        # sample a low-resolution but full image for CLIP
        rays = get_rays(poses, intrinsics, self.H, self.W, -1)

        data = {
            'H': self.H,
            'W': self.W,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'dir': dirs,
        }

        return data
    def dataloader(self):
        loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
        loader._data = self  # an ugly fix... we need to access dataset in trainer.
        return loader
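
# Minimal usage sketch (illustrative; `opt` is a hypothetical namespace carrying
# the fields this file reads -- the values below are examples, not prescribed):
#
#   from argparse import Namespace
#   opt = Namespace(radius_range=[1.0, 1.5], fovy_range=[40, 70], dir_text=True,
#                   angle_overhead=30, angle_front=60, jitter_pose=False)
#   loader = NeRFDataset(opt, device=torch.device('cpu'), type='train').dataloader()
#   data = next(iter(loader))
#   # data['rays_o'] / data['rays_d'] hold one full image of rays (expected
#   # [1, H*W, 3], assuming get_rays returns all rays when num_rays == -1).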