Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d | |
| from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d | |
| import einops | |
| from modules.util import UpBlock2d, DownBlock2d | |
def make_coordinate_grid(spatial_size, type):
    """
    Build a dense 3-D coordinate grid normalized to [-1, 1].

    Args:
        spatial_size: (d, h, w) tuple of grid dimensions.
        type: torch dtype or dtype string used for the grid (e.g. the
            result of ``tensor.type()``).

    Returns:
        Tensor of shape (d, h, w, 3) whose last axis holds the (x, y, z)
        coordinates of every voxel, each axis mapped linearly into [-1, 1].
    """
    d, h, w = spatial_size
    x = torch.arange(w).type(type)
    y = torch.arange(h).type(type)
    z = torch.arange(d).type(type)

    # Normalize each axis to [-1, 1]. Guard the singleton case (size == 1),
    # which would otherwise compute 0/0 and fill the grid with NaNs.
    x = 2 * (x / max(w - 1, 1)) - 1
    y = 2 * (y / max(h - 1, 1)) - 1
    z = 2 * (z / max(d - 1, 1)) - 1

    yy = y.view(1, -1, 1).repeat(d, 1, w)
    xx = x.view(1, 1, -1).repeat(d, h, 1)
    zz = z.view(-1, 1, 1).repeat(1, h, w)

    meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
    return meshed
def kp2gaussian_3d(kp, spatial_size, kp_variance):
    """
    Convert keypoint coordinates into a gaussian-like heatmap volume.

    Args:
        kp: keypoint tensor whose last axis holds 3-D coordinates
            (presumably normalized to [-1, 1], matching make_coordinate_grid
            — confirm against the caller).
        spatial_size: (d, h, w) of the output volume.
        kp_variance: scalar variance of the gaussian.

    Returns:
        Tensor shaped (*kp_leading_dims, d, h, w) holding
        exp(-0.5 * ||grid - kp||^2 / kp_variance).
    """
    mean = kp
    grid = make_coordinate_grid(spatial_size, mean.type())

    # Broadcast the (d, h, w, 3) grid across every leading keypoint dim.
    leading = len(mean.shape) - 1
    grid = grid.view((1,) * leading + grid.shape)
    grid = grid.repeat(*(mean.shape[:leading] + (1, 1, 1, 1)))

    # Reshape keypoints so their coordinates align with the grid's last axis.
    mean = mean.view(mean.shape[:leading] + (1, 1, 1, 3))

    diff = grid - mean
    return torch.exp(-0.5 * (diff ** 2).sum(-1) / kp_variance)
class ResBlock3d(nn.Module):
    """
    3-D residual block that preserves spatial resolution.

    Pre-activation layout: applies (norm -> relu -> conv) twice, then adds
    the input back.
    """

    def __init__(self, in_features, kernel_size, padding):
        super(ResBlock3d, self).__init__()
        conv_kwargs = dict(in_channels=in_features, out_channels=in_features,
                           kernel_size=kernel_size, padding=padding)
        self.conv1 = nn.Conv3d(**conv_kwargs)
        self.conv2 = nn.Conv3d(**conv_kwargs)
        self.norm1 = BatchNorm3d(in_features, affine=True)
        self.norm2 = BatchNorm3d(in_features, affine=True)

    def forward(self, x):
        residual = x
        h = self.conv1(F.relu(self.norm1(x)))
        h = self.conv2(F.relu(self.norm2(h)))
        return h + residual
class rgb_predictor(nn.Module):
    """
    Reduce a warped feature map to per-floor color encodings.

    A 3x3 conv shrinks the channel count to ``simpled_channel``; the channel
    axis is then split into ``floor_num`` depth slices.
    """

    def __init__(self, in_channels, simpled_channel=128, floor_num=8):
        super(rgb_predictor, self).__init__()
        self.floor_num = floor_num
        self.down_conv = nn.Conv2d(in_channels=in_channels, out_channels=simpled_channel,
                                   kernel_size=3, padding=1)

    def forward(self, feature):
        """
        Args:
            feature: warped feature map, (bs, c, h, w).

        Returns:
            Tensor of shape (bs, h, w, floor_num, simpled_channel // floor_num).
        """
        feature = self.down_conv(feature)
        bs, ch, h, w = feature.shape
        # b (c f) h w -> b c f h w: split channels with c outer, f inner
        # (matches einops '(c f)' grouping), then move spatial axes forward:
        # b c f h w -> b h w f c.
        feature = feature.view(bs, ch // self.floor_num, self.floor_num, h, w)
        return feature.permute(0, 3, 4, 2, 1)
class sigma_predictor(nn.Module):
    """
    Predict per-floor density (sigma) encodings from a warped feature map.

    A 3x3 conv reduces the channels to ``simpled_channel``; the result is
    folded into a 3-D volume of ``floor_num`` depth slices and refined by a
    stack of 3-D residual blocks.
    """

    def __init__(self, in_channels, simpled_channel=128, floor_num=8):
        super(sigma_predictor, self).__init__()
        self.floor_num = floor_num
        self.down_conv = nn.Conv2d(in_channels=in_channels, out_channels=simpled_channel,
                                   kernel_size=3, padding=1)
        # Channels per floor after the (c f) split. The previous version
        # hard-coded 16, which only matches the 128/8 defaults and breaks
        # for any other simpled_channel/floor_num combination.
        per_floor = simpled_channel // floor_num
        self.res_conv3d = nn.Sequential(
            ResBlock3d(per_floor, 3, 1),
            nn.BatchNorm3d(per_floor),
            ResBlock3d(per_floor, 3, 1),
            nn.BatchNorm3d(per_floor),
            ResBlock3d(per_floor, 3, 1),
            nn.BatchNorm3d(per_floor)
        )

    def forward(self, feature):
        """
        Args:
            feature: warped feature map, (bs, c, h, w).

        Returns:
            sigma: (bs, h, w, floor_num, simpled_channel // floor_num)
            point_dict: {'sigma_map': refined (bs, c', floor_num, h, w) volume}
        """
        heatmap = self.down_conv(feature)
        heatmap = einops.rearrange(heatmap, "b (c f) h w -> b c f h w", f=self.floor_num)
        heatmap = self.res_conv3d(heatmap)
        sigma = einops.rearrange(heatmap, "b c f h w -> b h w f c")
        point_dict = {'sigma_map': heatmap}
        return sigma, point_dict
class MultiHeadNeRFModel(torch.nn.Module):
    """
    Two-headed NeRF-style MLP.

    The sigma branch (layer1/layer2) encodes the position features and emits
    a scalar density plus a feature vector; the color branch fuses that
    feature vector with the rgb/view encoding and emits a 256-dim feature.
    """

    def __init__(self, hidden_size=128, num_encoding_rgb=16, num_encoding_sigma=16):
        super(MultiHeadNeRFModel, self).__init__()
        self.xyz_encoding_dims = num_encoding_sigma
        self.viewdir_encoding_dims = num_encoding_rgb
        # Sigma trunk: encoding -> hidden -> hidden.
        self.layer1 = torch.nn.Linear(self.xyz_encoding_dims, hidden_size)
        self.layer2 = torch.nn.Linear(hidden_size, hidden_size)
        # Density head (scalar) and the feature passed to the color branch.
        self.layer3_1 = torch.nn.Linear(hidden_size, 1)
        self.layer3_2 = torch.nn.Linear(hidden_size, hidden_size // 4)
        # Projection of the rgb/view encoding.
        self.layer3_3 = torch.nn.Linear(self.viewdir_encoding_dims, hidden_size)
        # Color trunk: fused features -> hidden -> hidden -> 256-dim output.
        self.layer4 = torch.nn.Linear(hidden_size // 4 + hidden_size, hidden_size)
        self.layer5 = torch.nn.Linear(hidden_size, hidden_size)
        self.layer6 = torch.nn.Linear(hidden_size, 256)
        # Shorthand for the activation used throughout forward().
        self.relu = torch.nn.functional.relu

    def forward(self, rgb_in, sigma_in):
        """
        Args:
            rgb_in: rgb/view encoding, (bs, h, w, floor, num_encoding_rgb).
            sigma_in: position encoding, (bs, h, w, floor, num_encoding_sigma).

        Returns:
            (color_features, sigma): tensors shaped (..., 256) and (..., 1).
        """
        bs, h, w, floor_num, _ = rgb_in.size()
        relu = self.relu
        trunk = relu(self.layer2(relu(self.layer1(sigma_in))))
        sigma = relu  # placeholder removed below
        sigma = self.layer3_1(trunk)
        fused = torch.cat(
            (relu(self.layer3_2(trunk)), relu(self.layer3_3(rgb_in))), dim=-1
        )
        fused = relu(self.layer5(relu(self.layer4(fused))))
        return self.layer6(fused), sigma
def volume_render(rgb_pred, sigma_pred):
    """
    Alpha-composite per-floor colors along the depth (floor) axis.

    Args:
        rgb_pred: (bs, h, w, floor, channels) per-floor color features.
        sigma_pred: (bs, h, w, floor, sigma_channels) per-floor densities.

    Returns:
        Composited feature map of shape (bs, channels, h, w).

    NOTE(review): the accumulated transmittance already includes the current
    floor's density (it is updated before being used), which differs from the
    textbook NeRF weight exp(-sum_{j<i} sigma_j) — presumably intentional;
    confirm against the training code before changing.
    """
    _, _, _, floor_count, _ = sigma_pred.size()
    accum = 0
    log_transmittance = 0
    for idx in range(floor_count):
        density = torch.nn.functional.relu(sigma_pred[:, :, :, idx, :])
        log_transmittance = log_transmittance + (-density)
        alpha = 1 - torch.exp(-density)
        accum = accum + torch.exp(log_transmittance) * alpha * rgb_pred[:, :, :, idx, :]
    # b h w c -> b c h w
    return accum.permute(0, 3, 1, 2)
class RenderModel(nn.Module):
    """
    Neural-rendering head: predicts per-floor rgb/sigma encodings from a
    warped feature map, runs them through the NeRF MLP, volume-renders the
    result, and decodes a small RGB preview image.
    """

    def __init__(self, in_channels, simpled_channel_rgb, simpled_channel_sigma, floor_num, hidden_size):
        super(RenderModel, self).__init__()
        self.rgb_predict = rgb_predictor(in_channels=in_channels,
                                         simpled_channel=simpled_channel_rgb,
                                         floor_num=floor_num)
        self.sigma_predict = sigma_predictor(in_channels=in_channels,
                                             simpled_channel=simpled_channel_sigma,
                                             floor_num=floor_num)
        # Per-floor encoding widths consumed by the NeRF MLP.
        self.nerf_module = MultiHeadNeRFModel(
            hidden_size=hidden_size,
            num_encoding_rgb=simpled_channel_rgb // floor_num,
            num_encoding_sigma=simpled_channel_sigma // floor_num,
        )
        # Upsamples the 256-channel rendered features to a 3-channel image.
        self.mini_decoder = nn.Sequential(
            UpBlock2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            UpBlock2d(64, 3, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, feature):
        rgb_enc = self.rgb_predict(feature)
        sigma_enc, point_dict = self.sigma_predict(feature)
        rgb_out, sigma_out = self.nerf_module(rgb_enc, sigma_enc)
        rendered = torch.sigmoid(volume_render(rgb_out, sigma_out))
        preview = self.mini_decoder(rendered)
        return {'render': rendered, 'mini_pred': preview, 'point_pred': point_dict}