Spaces:
Runtime error
Runtime error
| import os | |
| import random | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| import matplotlib as mpl | |
| import matplotlib.cm as cm | |
| from vidar.core.wrapper import Wrapper | |
| from vidar.utils.config import read_config | |
def colormap_depth(depth_map):
    """Colour-code a depth map by its inverse depth (disparity) using 'magma'.

    Pixels whose depth is zero or non-finite are treated as invalid and painted
    white. The colour range is clipped to the 5th-95th percentile of the valid
    disparities so a few extreme values do not wash out the image.

    Args:
        depth_map: HxW numpy array of depth values; 0 marks a missing depth.

    Returns:
        HxWx3 uint8 numpy array with the colour-coded disparity.
    """
    depth_map = np.asarray(depth_map)
    # Zero means "no depth"; NaN/inf would poison the percentiles below.
    valid = np.isfinite(depth_map) & (depth_map != 0)
    colormapped_im = np.full(depth_map.shape + (3,), 255, dtype=np.uint8)
    if not valid.any():
        # np.percentile on an empty selection raises; there is nothing to colour.
        return colormapped_im
    # Invert only valid pixels, avoiding divide-by-zero warnings on masked ones.
    disp_map = np.zeros(depth_map.shape, dtype=np.float64)
    disp_map[valid] = 1.0 / depth_map[valid]
    vmin, vmax = np.percentile(disp_map[valid], [5, 95])
    normalizer = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
    rgb = (mapper.to_rgba(disp_map)[:, :, :3] * 255).astype(np.uint8)
    # Invalid pixels keep the white fill.
    colormapped_im[valid] = rgb[valid]
    return colormapped_im
def data_to_batch(data):
    """Prepare one dataset sample for the model by adding two leading singleton dims.

    The selected per-context tensors are expanded with ``unsqueeze(0).unsqueeze(0)``.
    The nested per-context dicts are copied first, so the caller's sample is left
    untouched. A plain ``data.copy()`` is shallow: it would share those dicts, and
    the expansion would mutate them.

    Args:
        data: sample dict as returned by the dataset. It must contain 'rgb',
            'pose' and 'depth' dicts with entries 0 and 1, and an 'intrinsics'
            dict with entry 0; a KeyError is raised otherwise.

    Returns:
        A new dict with the same keys as ``data``; all other entries are
        passed through unchanged.
    """
    # (key, context) entries that receive the two extra leading dimensions.
    expand = (
        ('rgb', 0), ('rgb', 1),
        ('intrinsics', 0),
        ('pose', 0), ('pose', 1),
        ('depth', 0), ('depth', 1),
    )
    batch = data.copy()
    for key in {name for name, _ in expand}:
        batch[key] = batch[key].copy()
    for key, ctx in expand:
        batch[key][ctx] = batch[key][ctx].unsqueeze(0).unsqueeze(0)
    return batch
# DIST_MODE is set before the Wrapper is built; presumably vidar reads it to
# pick its device backend (GPU when CUDA is available, otherwise CPU).
os.environ['DIST_MODE'] = 'gpu' if torch.cuda.is_available() else 'cpu'

# Model + dataset configuration (ScanNet temporal test setup, per the file name).
cfg_file_path = 'configs/papers/define/scannet_temporal_test_context_1.yaml'
cfg = read_config(cfg_file_path)
wrapper = Wrapper(cfg, verbose=True)

# Network used for inference; eval() switches the module to evaluation mode.
arch = wrapper.arch
arch.eval()

# The UI samples indices from the first validation dataset.
val_dataset = wrapper.datasets['validation'][0]
len_val_dataset = len(val_dataset)
def sample_data_idx():
    """Draw a uniformly random index into the validation dataset."""
    upper = len_val_dataset  # exclusive bound
    return random.randrange(upper)
def display_images_from_idx(idx):
    """Return the RGB frames of validation sample ``idx`` for the image gallery.

    Args:
        idx: index as held by the index textbox (usually a string); it is
            converted with ``int()``.

    Returns:
        List of numpy arrays, one per entry of the sample's 'rgb' dict, with
        the first tensor axis moved last.

    Raises:
        gr.Error: if ``idx`` is not an integer in [0, len_val_dataset - 1].
    """
    try:
        index = int(idx)
    except (TypeError, ValueError):
        raise gr.Error(f"Data index must be an integer, got {idx!r}.") from None
    if not 0 <= index < len_val_dataset:
        raise gr.Error(
            f"Data index must be between 0 and {len_val_dataset - 1}, got {index}."
        )
    rgbs = val_dataset[index]['rgb']
    # permute(1, 2, 0) moves the leading (presumably channel) axis last, as image viewers expect.
    return [np.array(rgb.permute(1, 2, 0)) for rgb in rgbs.values()]
def infer_depth_from_idx(idx):
    """Run the depth network on validation sample ``idx`` and colour-map the output.

    Args:
        idx: index as held by the index textbox (usually a string); it is
            converted with ``int()``.

    Returns:
        List of HxWx3 uint8 images, one per entry of the predicted 'depth' dict.

    Raises:
        gr.Error: if ``idx`` is not an integer in [0, len_val_dataset - 1].
    """
    try:
        index = int(idx)
    except (TypeError, ValueError):
        raise gr.Error(f"Data index must be an integer, got {idx!r}.") from None
    if not 0 <= index < len_val_dataset:
        raise gr.Error(
            f"Data index must be between 0 and {len_val_dataset - 1}, got {index}."
        )
    batch = data_to_batch(val_dataset[index])
    # Inference only: do not build an autograd graph (saves memory and time).
    with torch.no_grad():
        output = arch(batch, epoch=0)
    output_depths = output['predictions']['depth']
    # For each entry, take the first prediction and drop three leading singleton axes.
    # .cpu() keeps .numpy() working if the model runs on a GPU.
    return [
        colormap_depth(
            depth[0].squeeze(0).squeeze(0).squeeze(0).detach().cpu().numpy()
        )
        for depth in output_depths.values()
    ]
with gr.Blocks() as demo:
    # Layout: sampled RGB frames, the index box, and the predicted depth maps.
    img_box = gr.Gallery(label="Sampled Images").style(grid=[2], height="auto")
    data_idx_box = gr.Textbox(
        label="Sampled Data Index",
        placeholder=f"Number between 0 and {len_val_dataset - 1}",
        interactive=True,
    )
    sample_btn = gr.Button('Sample Dataset')
    depth_box = gr.Gallery(label="Inferred Depth").style(grid=[2], height="auto")
    infer_btn = gr.Button('Depth Infer')

    # Actions: "Sample Dataset" writes a random index into the textbox and,
    # only if that succeeds, shows the RGB frames of that sample.
    sample_btn.click(
        fn=sample_data_idx,
        inputs=None,
        outputs=data_idx_box,
    ).success(
        fn=display_images_from_idx,
        inputs=data_idx_box,
        outputs=img_box,
    )
    # "Depth Infer" runs the model on whatever index is currently in the textbox.
    infer_btn.click(
        fn=infer_depth_from_idx,
        inputs=data_idx_box,
        outputs=depth_box,
    )

demo.launch(server_name="0.0.0.0", server_port=7860)