import sys
import os
from subprocess import check_call
import tempfile
from os.path import basename, splitext, join
from io import BytesIO

import numpy as np
from scipy.spatial import KDTree
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, to_pil_image
from einops import rearrange
import gradio as gr
from huggingface_hub import hf_hub_download

from extern.ZoeDepth.zoedepth.utils.misc import colorize
from gradio_model3dgscamera import Model3DGSCamera

IMAGE_SIZE = 512
NEAR, FAR = 0.01, 100
FOVY = np.deg2rad(55)
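# For reference (an illustrative note, not in the original source): with a
# square IMAGE_SIZE x IMAGE_SIZE viewport, the pinhole focal length implied
# by FOVY is f = IMAGE_SIZE / (2 * tan(FOVY / 2)), about 492 px here. The
# Model3DGSCamera widget below is configured with the same expression.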
def download_models():
    models = [
        {
            'repo': 'stabilityai/sd-vae-ft-mse',
            'sub': None,
            'dst': 'checkpoints/sd-vae-ft-mse',
            'files': ['config.json', 'diffusion_pytorch_model.safetensors'],
            'token': None
        },
        {
            'repo': 'lambdalabs/sd-image-variations-diffusers',
            'sub': 'image_encoder',
            'dst': 'checkpoints',
            'files': ['config.json', 'pytorch_model.bin'],
            'token': None
        },
        {
            'repo': 'Sony/genwarp',
            'sub': 'multi1',
            'dst': 'checkpoints',
            'files': ['config.json', 'denoising_unet.pth', 'pose_guider.pth', 'reference_unet.pth'],
            'token': None
        }
    ]
    for model in models:
        for file in model['files']:
            hf_hub_download(
                repo_id=model['repo'],
                subfolder=model['sub'],
                filename=file,
                local_dir=model['dst'],
                token=model['token']
            )
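# A minimal sanity check (illustrative sketch, not in the original): after
# download_models() the VAE weights should sit at
# checkpoints/sd-vae-ft-mse/diffusion_pytorch_model.safetensors, matching the
# pretrained_model_path used in genwarp_cfg below.
# assert os.path.exists('checkpoints/sd-vae-ft-mse/diffusion_pytorch_model.safetensors')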
# Crop the image to a centered square whose side equals the shorter edge.
def crop(img: Image.Image) -> Image.Image:
    W, H = img.size
    if W < H:
        left, right = 0, W
        top, bottom = np.ceil((H - W) / 2.), np.floor((H - W) / 2.) + W
    else:
        left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
        top, bottom = 0, H
    return img.crop((left, top, right, bottom))
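# Worked example (illustrative): a 640x480 input gives left=80, right=560,
# top=0, bottom=480 -- a centered 480x480 crop, which callers then resize
# to IMAGE_SIZE x IMAGE_SIZE.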
def unproject(depth):
    H, W = depth.shape[2:4]
    mean_depth = depth.mean(dim=(2, 3)).squeeze().item()

    viewport_mtx = get_viewport_matrix(
        IMAGE_SIZE, IMAGE_SIZE,
        batch_size=1
    ).to(depth)

    # Projection matrix.
    fovy = torch.ones(1) * FOVY
    proj_mtx = get_projection_matrix(
        fovy=fovy,
        aspect_wh=1.,
        near=NEAR,
        far=FAR
    ).to(depth)

    # Canonical source camera at the origin, looking down +z with y down.
    view_mtx = camera_lookat(
        torch.tensor([[0., 0., 0.]]),
        torch.tensor([[0., 0., 1.]]),
        torch.tensor([[0., -1., 0.]])
    ).to(depth)

    scr_mtx = (viewport_mtx @ proj_mtx).to(depth)

    grid = torch.stack(torch.meshgrid(
        torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
    ).to(depth)[None]  # BHW2

    # Lift pixels to homogeneous screen coordinates [x, y, 0, 1].
    screen = F.pad(grid, (0, 1), 'constant', 0)
    screen = F.pad(screen, (0, 1), 'constant', 1)
    screen_flat = rearrange(screen, 'b h w c -> b (h w) c')

    # Screen -> eye space, then scale by metric depth.
    eye = screen_flat @ torch.linalg.inv_ex(
        scr_mtx.float()
    )[0].mT.to(depth)
    eye = eye * rearrange(depth, 'b c h w -> b (h w) c')
    eye[..., 3] = 1

    # Eye -> world space.
    points = eye @ torch.linalg.inv_ex(view_mtx.float())[0].mT.to(depth)
    points = points[0, :, :3]

    # Recenter so the mean depth sits at the origin, and move the camera back.
    points[..., 2] -= mean_depth
    camera_pos = (0, 0, -mean_depth)
    view_mtx = camera_lookat(
        torch.tensor([[0., 0., -mean_depth]]),
        torch.tensor([[0., 0., 0.]]),
        torch.tensor([[0., -1., 0.]])
    ).to(depth)

    return points, camera_pos, view_mtx, proj_mtx
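# Sketch of the unprojection above, in the row-vector convention the code
# uses: a pixel (x, y) becomes the homogeneous screen point [x, y, 0, 1],
# goes to eye space via (viewport @ proj)^-1, is scaled by its metric depth,
# re-homogenised (w = 1), and mapped to world space via view^-1.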
def calc_dist2(points: np.ndarray):
    # Mean squared distance to the 3 nearest neighbours; k=4 because the
    # query set is the point set itself, so the first hit (distance 0) is
    # skipped.
    dists, _ = KDTree(points).query(points, k=4)
    mean_dists = (dists[:, 1:] ** 2).mean(1)
    return mean_dists
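# calc_dist2 seeds the per-splat scale in save_as_splat below: each Gaussian
# is initialised roughly as large as the local point spacing, so the splats
# cover the surface without visible holes.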
def save_as_splat(
    filepath: str,
    xyz: np.ndarray,
    rgb: np.ndarray
):
    # Convert the coloured point cloud to a Gaussian splat.
    inv_sigmoid = lambda x: np.log(x / (1 - x))
    dist2 = np.clip(calc_dist2(xyz), a_min=0.0000001, a_max=None)
    scales = np.repeat(np.log(np.sqrt(dist2))[..., np.newaxis], 3, axis=1)
    rots = np.zeros((xyz.shape[0], 4))
    rots[:, 0] = 1
    opacities = inv_sigmoid(0.1 * np.ones((xyz.shape[0], 1)))

    # Sort by descending (volume x opacity) so large, opaque splats are
    # written first.
    sorted_indices = np.argsort((
        -np.exp(np.sum(scales, axis=-1, keepdims=True))
        / (1 + np.exp(-opacities))
    ).squeeze())
    buffer = BytesIO()
    for idx in sorted_indices:
        position = xyz[idx]
        scale = np.exp(scales[idx]).astype(np.float32)
        rot = rots[idx].astype(np.float32)
        color = np.concatenate(
            (rgb[idx], 1 / (1 + np.exp(-opacities[idx]))),
            axis=-1
        )
        buffer.write(position.tobytes())
        buffer.write(scale.tobytes())
        buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
        buffer.write(
            ((rot / np.linalg.norm(rot)) * 128 + 128)
            .clip(0, 255)
            .astype(np.uint8)
            .tobytes()
        )
    with open(filepath, 'wb') as f:
        f.write(buffer.getvalue())
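# Each splat is packed as 32 bytes, matching the .splat layout used by
# antimatter15-style web viewers (an assumption based on the writes above):
# 3 x float32 position, 3 x float32 scale, 4 x uint8 RGBA, and 4 x uint8
# quaternion remapped from [-1, 1] to [0, 255].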
def view_from_rt(position, rotation):
    t = np.array(position)
    euler = np.array(rotation)

    cx = np.cos(euler[0])
    sx = np.sin(euler[0])
    cy = np.cos(euler[1])
    sy = np.sin(euler[1])
    cz = np.cos(euler[2])
    sz = np.sin(euler[2])
    R = np.array([
        cy * cz + sy * sx * sz,
        -cy * sz + sy * sx * cz,
        sy * cx,
        cx * sz,
        cx * cz,
        -sx,
        -sy * cz + cy * sx * sz,
        sy * sz + cy * sx * cz,
        cy * cx
    ])
    view_mtx = np.array([
        [R[0], R[1], R[2], 0],
        [R[3], R[4], R[5], 0],
        [R[6], R[7], R[8], 0],
        [
            -t[0] * R[0] - t[1] * R[3] - t[2] * R[6],
            -t[0] * R[1] - t[1] * R[4] - t[2] * R[7],
            -t[0] * R[2] - t[1] * R[5] - t[2] * R[8],
            1
        ]
    ]).T

    B = np.array([
        [1, 0, 0, 0],
        [0, -1, 0, 0],
        [0, 0, -1, 0],
        [0, 0, 0, 1]
    ])
    return B @ view_mtx
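# Reading of the code above: R is the rotation assembled from the viewer's
# Euler angles, the last row folds in the translation as -t R, and the
# transpose yields a world-to-camera matrix; B then flips the Y and Z axes,
# presumably bridging the 3DGS viewer's coordinate convention and the
# pipeline's (an interpretation, not documented in the original).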
# Setup.
download_models()

# ZoeDepth monocular depth estimator, loaded from the local checkout.
mde = torch.hub.load(
    './extern/ZoeDepth',
    'ZoeD_N',
    source='local',
    pretrained=True,
    trust_repo=True
)

import spaces

# Install the bundled splatting wheel before importing genwarp, which
# depends on it.
check_call([
    sys.executable, '-m', 'pip', 'install',
    'extern/splatting-0.0.1-py3-none-any.whl'
])

from genwarp import GenWarp
from genwarp.ops import (
    camera_lookat, get_projection_matrix, get_viewport_matrix
)

# GenWarp model.
genwarp_cfg = dict(
    pretrained_model_path='checkpoints',
    checkpoint_name='multi1',
    half_precision_weights=True
)
genwarp_nvs = GenWarp(cfg=genwarp_cfg, device='cpu')
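# Note: the model is instantiated on CPU and moved to CUDA per request in
# cb_generate (genwarp_nvs.to('cuda')), a common pattern on Spaces where
# the GPU is attached lazily.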
with tempfile.TemporaryDirectory() as tmpdir:
    with gr.Blocks(
        title='GenWarp Demo',
        css='img {display: inline;}'
    ) as demo:
        # Internal states.
        src_image = gr.State()
        src_depth = gr.State()
        proj_mtx = gr.State()
        src_view_mtx = gr.State()

        # Blocks.
        gr.Markdown(
            """
# GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping

[Project Page](https://genwarp-nvs.github.io/) | [Demo](https://huggingface.co/spaces/Sony/GenWarp) | [Code](https://github.com/sony/genwarp/) | [Model](https://huggingface.co/Sony/genwarp) | [Paper](https://arxiv.org/abs/2405.17251)

## Introduction
This is an official demo for the paper "[GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping](https://genwarp-nvs.github.io/)". GenWarp generates novel-view images from a single input image, conditioned on a camera pose. This demo provides basic inference with the model; for details, please refer to the [paper](https://arxiv.org/abs/2405.17251).

## How to Use
1. Upload a reference image to "Reference Input"
    - You can also select an image from "Examples"
2. Move the camera to your desired view in the "Unprojected 3DGS" 3D viewer
3. Hit the "Generate a novel view" button and check the result
            """
        )
        file = gr.File(label='Reference Input', file_types=['image'])
        examples = gr.Examples(
            examples=['./assets/pexels-heyho-5998120_19mm.jpg',
                      './assets/pexels-itsterrymag-12639296_24mm.jpg'],
            inputs=file
        )
        with gr.Row():
            image_widget = gr.Image(
                label='Reference View', type='filepath',
                interactive=False
            )
            depth_widget = gr.Image(label='Estimated Depth', type='pil')
            viewer = Model3DGSCamera(
                label='Unprojected 3DGS',
                width=IMAGE_SIZE,
                height=IMAGE_SIZE,
                camera_width=IMAGE_SIZE,
                camera_height=IMAGE_SIZE,
                camera_fx=IMAGE_SIZE / (np.tan(FOVY / 2.)) / 2.,
                camera_fy=IMAGE_SIZE / (np.tan(FOVY / 2.)) / 2.,
                camera_near=NEAR,
                camera_far=FAR
            )
        button = gr.Button('Generate a novel view', size='lg', variant='primary')
        with gr.Row():
            warped_widget = gr.Image(
                label='Warped Image', type='pil', interactive=False
            )
            gen_widget = gr.Image(
                label='Generated View', type='pil', interactive=False
            )
        # Callbacks.
        def cb_mde(image_file: str):
            # Crop, resize, and estimate depth with ZoeDepth.
            image = to_tensor(crop(Image.open(
                image_file
            ).convert('RGB')).resize((IMAGE_SIZE, IMAGE_SIZE)))[None].cuda()
            depth = mde.cuda().infer(image)
            depth_image = to_pil_image(colorize(depth[0]))
            return to_pil_image(image[0]), depth_image, image.cpu().detach(), depth.cpu().detach()

        def cb_3d(image, depth, image_file):
            # Unproject to a point cloud and save it as a .splat for the viewer.
            xyz, camera_pos, view_mtx, proj_mtx = unproject(depth.cuda())
            rgb = rearrange(image, 'b c h w -> b (h w) c')[0]
            splat_file = join(tmpdir, f'{splitext(basename(image_file))[0]}.splat')
            save_as_splat(splat_file, xyz.cpu().detach().numpy(), rgb.cpu().detach().numpy())
            return (splat_file, camera_pos, None), view_mtx.cpu().detach(), proj_mtx.cpu().detach()
        def cb_generate(viewer, image, depth, src_view_mtx, proj_mtx):
            image = image.cuda()
            depth = depth.cuda()
            src_view_mtx = src_view_mtx.cuda()
            proj_mtx = proj_mtx.cuda()

            # The viewer's current camera pose defines the target view.
            src_camera_pos = viewer[1]
            src_camera_rot = viewer[2]
            tar_view_mtx = view_from_rt(src_camera_pos, src_camera_rot)
            tar_view_mtx = torch.from_numpy(tar_view_mtx).to(image)
            # Relative pose from the source view to the target view.
            rel_view_mtx = (
                tar_view_mtx @ torch.linalg.inv(src_view_mtx.to(image))
            ).to(image)

            # GenWarp runs in half precision, matching its weights.
            renders = genwarp_nvs.to('cuda')(
                src_image=image.half(),
                src_depth=depth.half(),
                rel_view_mtx=rel_view_mtx.half(),
                src_proj_mtx=proj_mtx.half(),
                tar_proj_mtx=proj_mtx.half()
            )

            warped = renders['warped']
            synthesized = renders['synthesized']
            warped_pil = to_pil_image(warped[0])
            synthesized_pil = to_pil_image(synthesized[0])
            return warped_pil, synthesized_pil
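        # renders['warped'] is the purely geometric forward warp of the source
        # view; renders['synthesized'] is the generative result conditioned on
        # it, so comparing the two widgets shows what the model fills in.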
        # Events.
        file.change(
            fn=cb_mde,
            inputs=file,
            outputs=[image_widget, depth_widget, src_image, src_depth]
        ).then(
            fn=cb_3d,
            inputs=[src_image, src_depth, image_widget],
            outputs=[viewer, src_view_mtx, proj_mtx]
        )
        button.click(
            fn=cb_generate,
            inputs=[viewer, src_image, src_depth, src_view_mtx, proj_mtx],
            outputs=[warped_widget, gen_widget]
        )

    # Launch inside the TemporaryDirectory context so the generated .splat
    # files remain available while the app is running.
    if __name__ == '__main__':
        demo.launch()