Spaces:
Runtime error
Runtime error
| import csv | |
| import copy | |
| import torch | |
| import einops | |
| import numpy as np | |
| from torch import nn | |
| import torch.nn.functional as F | |
def get_activation_fn(activation_type):
    """Look up an activation function in ``torch.nn.functional`` by name.

    Args:
        activation_type: Name of the activation; one of ``"relu"``,
            ``"gelu"`` or ``"glu"``.

    Returns:
        The attribute of ``torch.nn.functional`` carrying that name.

    Raises:
        RuntimeError: If ``activation_type`` is not one of the supported names.
    """
    supported = ("relu", "gelu", "glu")
    if activation_type not in supported:
        # Build the message from the allow-list so the two cannot drift apart.
        raise RuntimeError(
            f"activation function currently support {'/'.join(supported)}, not {activation_type}"
        )
    return getattr(F, activation_type)
def get_mlp_head(input_size, hidden_size, output_size, dropout=0):
    """Build a small MLP head as an ``nn.Sequential``.

    The stack is, in order: Linear(input_size, hidden_size), ReLU,
    LayerNorm(hidden_size, eps=1e-12), Dropout(dropout),
    Linear(hidden_size, output_size).

    Args:
        input_size: Feature size fed to the first linear layer.
        hidden_size: Width of the hidden representation.
        output_size: Feature size produced by the last linear layer.
        dropout: Probability passed to ``nn.Dropout``.

    Returns:
        The assembled ``nn.Sequential`` module.
    """
    # The two linear layers are created in their original order so that the
    # random initialisation consumes the RNG in the same sequence.
    hidden_proj = nn.Linear(input_size, hidden_size)
    output_proj = nn.Linear(hidden_size, output_size)
    return nn.Sequential(
        hidden_proj,
        nn.ReLU(),
        nn.LayerNorm(hidden_size, eps=1e-12),
        nn.Dropout(dropout),
        output_proj,
    )
def layer_repeat(module, N, share_layer=False):
    """Stack ``module`` into an ``nn.ModuleList``.

    Args:
        module: The layer to repeat.
        N: Requested number of entries.
        share_layer: When true, every entry is the very same ``module``
            object, so all entries share parameters. Otherwise the list holds
            ``N - 1`` deep copies followed by ``module`` itself as last entry.

    Returns:
        The resulting ``nn.ModuleList``.
    """
    if not share_layer:
        # Independent copies first; the original instance is kept as the tail.
        layers = [copy.deepcopy(module) for _ in range(N - 1)]
        layers.append(module)
        return nn.ModuleList(layers)
    return nn.ModuleList([module for _ in range(N)])
def calc_pairwise_locs(obj_centers, obj_whls, eps=1e-10, pairwise_rel_type='center', spatial_dist_norm=True,
                       spatial_dim=5):
    """Compute pairwise spatial features for every ordered pair of objects.

    Args:
        obj_centers: Tensor indexed as (b, l, d); components 0, 1 and 2 of the
            last axis are read (presumably x, y, z -- confirm with caller).
        obj_whls: Tensor indexed as (b, l, d) of object sizes. Only used by the
            'mlp' and 'vertical_bottom' modes.
        eps: Constant added under every square root, which keeps all distances
            strictly positive so they can be used as divisors.
        pairwise_rel_type: 'mlp' returns the concatenated [center, whl]
            descriptors of both objects of each pair. 'center' and
            'vertical_bottom' return five channels per pair. Any other value
            leaves the raw center offsets as the feature.
        spatial_dist_norm: When true, center distances are divided by the
            largest pairwise distance of the same batch element.
        spatial_dim: 1 returns only the (optionally normalised) distance as a
            single channel; 4 drops the first channel of the result; any other
            value keeps all channels.

    Returns:
        A (b, l, l, c) tensor whose entry [b, i, j] describes object i relative
        to object j.
    """

    def pairwise_offsets(points):
        # Broadcasting (b, l, 1, d) against (b, 1, l, d): entry [b, i, j] is
        # points[b, i] - points[b, j].
        return points[:, :, None, :] - points[:, None, :, :]

    def lengths(offsets):
        return torch.sqrt((offsets ** 2).sum(3) + eps)

    if pairwise_rel_type == 'mlp':
        descriptors = torch.cat([obj_centers, obj_whls], 2)
        num_objs = descriptors.size(1)
        # First half varies with the row object, second half with the column.
        by_row = descriptors[:, :, None, :].expand(-1, -1, num_objs, -1)
        by_col = descriptors[:, None, :, :].expand(-1, num_objs, -1, -1)
        return torch.cat([by_row, by_col], dim=3)

    center_offsets = pairwise_offsets(obj_centers)
    center_dists = lengths(center_offsets)  # (b, l, l)

    scaled_dists = center_dists
    if spatial_dist_norm:
        flat = center_dists.view(center_dists.size(0), -1)
        farthest = flat.max(dim=1).values
        scaled_dists = center_dists / farthest[:, None, None]

    if spatial_dim == 1:
        return scaled_dists.unsqueeze(3)

    # Length of the offset restricted to its first two components.
    center_dists_2d = lengths(center_offsets[..., :2])

    features = center_offsets
    if pairwise_rel_type == 'center':
        features = torch.stack(
            [
                scaled_dists,
                center_offsets[..., 2] / center_dists,
                center_dists_2d / center_dists,
                center_offsets[..., 1] / center_dists_2d,
                center_offsets[..., 0] / center_dists_2d,
            ],
            dim=3,
        )
    elif pairwise_rel_type == 'vertical_bottom':
        # Work on a copy so the caller's centers are not modified.
        shifted_centers = obj_centers.clone()
        shifted_centers[:, :, 2] -= obj_whls[:, :, 2]
        shifted_offsets = pairwise_offsets(shifted_centers)
        shifted_dists = lengths(shifted_offsets)  # (b, l, l)
        shifted_dists_2d = lengths(shifted_offsets[..., :2])
        # Channels 1-2 come from the shifted centers, the rest from the
        # unshifted ones.
        features = torch.stack(
            [
                scaled_dists,
                shifted_offsets[..., 2] / shifted_dists,
                shifted_dists_2d / shifted_dists,
                center_offsets[..., 1] / center_dists_2d,
                center_offsets[..., 0] / center_dists_2d,
            ],
            dim=3,
        )

    if spatial_dim == 4:
        features = features[..., 1:]
    return features
def convert_pc_to_box(obj_pc):
    """Return the axis-aligned bounding box of a point cloud.

    Args:
        obj_pc: 2-D array indexed as ``obj_pc[:, k]``; only columns 0, 1 and 2
            are read.

    Returns:
        ``(center, box_size)``: two 3-element lists holding, per column, the
        midpoint between minimum and maximum and the extent (max - min).
    """
    lows = [np.min(obj_pc[:, axis]) for axis in range(3)]
    highs = [np.max(obj_pc[:, axis]) for axis in range(3)]
    center = [(lo + hi) / 2 for lo, hi in zip(lows, highs)]
    box_size = [hi - lo for lo, hi in zip(lows, highs)]
    return center, box_size
class LabelConverter(object):
    """Lookup tables between label vocabularies, loaded from a TSV file.

    The first row of the file is skipped as a header. Every following row gets
    a consecutive ``raw_id`` starting at 0 and contributes:
    column 0 (parsed as int) as the ScanNet raw id, column 1 as the raw name,
    column 4 (parsed as int) as the NYU40 id and column 7 as the NYU40 name.
    """

    def __init__(self, file_path):
        """Read ``file_path`` (tab-separated, UTF-8) and fill all tables."""
        # Fixed mapping from class name to id; 'others' is the fallback entry.
        self.scannet_name_to_scannet_id = {'cabinet':0, 'bed':1, 'chair':2, 'sofa':3, 'table':4,
            'door':5, 'window':6,'bookshelf':7,'picture':8, 'counter':9, 'desk':10, 'curtain':11,
            'refrigerator':12, 'shower curtain':13, 'toilet':14, 'sink':15, 'bathtub':16, 'others':17}
        self.raw_name_to_id = {}
        self.nyu40id_to_id = {}
        self.nyu40_name_to_id = {}
        self.id_to_scannetid = {}
        self.scannet_raw_id_to_raw_name = {}
        self.raw_name_to_scannet_raw_id = {}

        with open(file_path, encoding='utf-8') as fd:
            rows = list(csv.reader(fd, delimiter="\t", quotechar='"'))

        fallback_id = self.scannet_name_to_scannet_id['others']
        for raw_id, row in enumerate(rows[1:]):
            scannet_raw_id = int(row[0])
            raw_name = row[1]
            nyu40_id = int(row[4])
            nyu40_name = row[7]

            self.raw_name_to_id[raw_name] = raw_id
            self.scannet_raw_id_to_raw_name[scannet_raw_id] = raw_name
            self.raw_name_to_scannet_raw_id[raw_name] = scannet_raw_id
            # Several rows may share one NYU40 id/name; the last row wins.
            self.nyu40id_to_id[nyu40_id] = raw_id
            self.nyu40_name_to_id[nyu40_name] = raw_id
            # NYU40 names missing from the fixed table map to 'others'.
            self.id_to_scannetid[raw_id] = self.scannet_name_to_scannet_id.get(nyu40_name, fallback_id)
def build_rotate_mat(split, rot_aug=True, rand_angle='axis'):
    """Sample a rotation about the third axis for training-time augmentation.

    An angle is always drawn from NumPy's global RNG, even when no matrix is
    returned: uniformly in [0, 2*pi) when ``rand_angle == 'random'``, otherwise
    one of 0, pi/2, pi, 3*pi/2.

    Args:
        split: A matrix is only produced when this equals ``'train'``.
        rot_aug: A matrix is only produced when this is truthy.
        rand_angle: ``'random'`` for a continuous angle; any other value picks
            one of the four axis-aligned angles.

    Returns:
        A 3x3 ``float32`` array, or ``None`` when the sampled angle is 0, the
        split is not ``'train'`` or ``rot_aug`` is falsy.
    """
    if rand_angle == 'random':
        theta = np.random.rand() * np.pi * 2
    else:
        axis_angles = (0, np.pi/2, np.pi, np.pi*3/2)
        theta = axis_angles[np.random.randint(len(axis_angles))]

    if theta == 0 or split != 'train' or not rot_aug:
        return None

    cos_t, sin_t = np.cos(theta), np.sin(theta)
    return np.array(
        [
            [cos_t, -sin_t, 0],
            [sin_t, cos_t, 0],
            [0, 0, 1],
        ],
        dtype=np.float32,
    )
def obj_processing_post(obj_pcds, rot_aug=True):
    """Derive per-object locations/boxes and normalise point clouds in place.

    Args:
        obj_pcds: NumPy array indexed as (object, point, channel); the first
            three channels are treated as coordinates. It is wrapped with
            ``torch.from_numpy``, which shares memory with the array, so the
            in-place centering and scaling below also modify the caller's data.
        rot_aug: Forwarded to ``build_rotate_mat``.

    Returns:
        ``(obj_pcds, obj_locs, obj_boxes, rot_matrix)`` where ``obj_locs`` is
        [mean point, extent] and ``obj_boxes`` is [min/max midpoint, extent]
        per object, both measured before normalisation. ``rot_matrix`` is the
        transposed rotation as a tensor, or ``None`` when no rotation applied.
    """
    obj_pcds = torch.from_numpy(obj_pcds)

    # NOTE(review): the split is hard-coded to 'val'; with build_rotate_mat as
    # defined in this file that yields None, so this branch looks inactive.
    rot_matrix = build_rotate_mat('val', rot_aug)
    if rot_matrix is not None:
        rot_matrix = torch.from_numpy(rot_matrix.transpose())
        obj_pcds[:, :, :3] @= rot_matrix

    # View onto the coordinate channels; in-place ops on it write to obj_pcds.
    coords = obj_pcds[:, :, :3]

    # Statistics must be taken before the coordinates are recentred.
    centroid = coords.mean(1)
    lower = coords.min(1).values
    upper = coords.max(1).values
    midpoint = (lower + upper) / 2
    extent = upper - lower
    obj_locs = torch.cat([centroid, extent], dim=1)
    obj_boxes = torch.cat([midpoint, extent], dim=1)

    # centering
    coords.sub_(coords.mean(1, keepdim=True))

    # normalization: scale by the largest point norm, floored at 1e-6 so a
    # degenerate object does not divide by zero.
    radius = coords.pow(2).sum(2).sqrt().max(1).values.clamp(min=1e-6)
    coords.div_(radius[:, None, None])

    return obj_pcds, obj_locs, obj_boxes, rot_matrix
def pad_sequence(sequence_list, max_len=None, pad=0, return_mask=False):
    """Stack variable-length tensors into one padded batch tensor.

    Args:
        sequence_list: Non-empty list of tensors that may differ in their
            first dimension. Dtype, device and trailing shape are taken from
            the first element.
        max_len: Target length of the first (sequence) dimension. Defaults to
            the longest sequence. Sequences longer than ``max_len`` are
            truncated to their first ``max_len`` entries (this used to fail
            with a shape-mismatch error).
        pad: Fill value for positions beyond a sequence's length.
        return_mask: When true, also return the padding mask.

    Returns:
        The padded tensor of shape ``(len(sequence_list), max_len, ...)``, and,
        if ``return_mask`` is set, a boolean mask of shape
        ``(len(sequence_list), max_len)`` that is True at padded positions.
    """
    lens = [x.shape[0] for x in sequence_list]
    if max_len is None:
        max_len = max(lens)

    first = sequence_list[0]
    dtype, device = first.dtype, first.device
    shape = [len(sequence_list), max_len] + list(first.shape[1:])

    # ones * pad lets type promotion absorb pad values the sequence dtype
    # cannot hold; the cast below restores the requested dtype.
    padded_sequence = torch.ones(shape, dtype=dtype, device=device) * pad
    for i, tensor in enumerate(sequence_list):
        keep = min(tensor.shape[0], max_len)
        padded_sequence[i, :keep] = tensor[:keep]
    padded_sequence = padded_sequence.to(dtype)

    if not return_mask:
        return padded_sequence

    # Build both operands directly on the target device.
    positions = torch.arange(max_len, device=device)
    lengths = torch.tensor(lens, dtype=torch.long, device=device)
    mask = positions[None, :] >= lengths[:, None]  # True as masked.
    return padded_sequence, mask