Spaces:
Runtime error
Runtime error
kxhit
commited on
Commit
·
8cf8c7b
1
Parent(s):
5f093a6
app
Browse files- EscherNet_Demo_Readme.md +0 -5
- app.py +780 -0
EscherNet_Demo_Readme.md
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
Run EscherNet using Dust3R log results, need to set `data_dir` and run:
|
| 2 |
-
```commandline
|
| 3 |
-
bash ./demo_dust3r.sh
|
| 4 |
-
```
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,780 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import os
|
| 5 |
+
import shutil
|
| 6 |
+
import rembg
|
| 7 |
+
import numpy as np
|
| 8 |
+
import math
|
| 9 |
+
import open3d as o3d
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import torch
|
| 12 |
+
import torchvision
|
| 13 |
+
import trimesh
|
| 14 |
+
from skimage.io import imsave
|
| 15 |
+
import imageio
|
| 16 |
+
import cv2
|
| 17 |
+
import matplotlib.pyplot as pl
|
| 18 |
+
pl.ion()
|
| 19 |
+
|
| 20 |
+
CaPE_TYPE = "6DoF"
|
| 21 |
+
device = 'cuda' #if torch.cuda.is_available() else 'cpu'
|
| 22 |
+
weight_dtype = torch.float16
|
| 23 |
+
torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
|
| 24 |
+
|
| 25 |
+
# EscherNet
|
| 26 |
+
# create angles in archimedean spiral with N steps
|
| 27 |
+
def get_archimedean_spiral(sphere_radius, num_steps=250):
|
| 28 |
+
# x-z plane, around upper y
|
| 29 |
+
'''
|
| 30 |
+
https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral". c = a / pi
|
| 31 |
+
'''
|
| 32 |
+
a = 40
|
| 33 |
+
r = sphere_radius
|
| 34 |
+
|
| 35 |
+
translations = []
|
| 36 |
+
angles = []
|
| 37 |
+
|
| 38 |
+
# i = a / 2
|
| 39 |
+
i = 0.01
|
| 40 |
+
while i < a:
|
| 41 |
+
theta = i / a * math.pi
|
| 42 |
+
x = r * math.sin(theta) * math.cos(-i)
|
| 43 |
+
z = r * math.sin(-theta + math.pi) * math.sin(-i)
|
| 44 |
+
y = r * - math.cos(theta)
|
| 45 |
+
|
| 46 |
+
# translations.append((x, y, z)) # origin
|
| 47 |
+
translations.append((x, z, -y))
|
| 48 |
+
angles.append([np.rad2deg(-i), np.rad2deg(theta)])
|
| 49 |
+
|
| 50 |
+
# i += a / (2 * num_steps)
|
| 51 |
+
i += a / (1 * num_steps)
|
| 52 |
+
|
| 53 |
+
return np.array(translations), np.stack(angles)
|
| 54 |
+
|
| 55 |
+
def look_at(origin, target, up):
|
| 56 |
+
forward = (target - origin)
|
| 57 |
+
forward = forward / np.linalg.norm(forward)
|
| 58 |
+
right = np.cross(up, forward)
|
| 59 |
+
right = right / np.linalg.norm(right)
|
| 60 |
+
new_up = np.cross(forward, right)
|
| 61 |
+
rotation_matrix = np.column_stack((right, new_up, -forward, target))
|
| 62 |
+
matrix = np.row_stack((rotation_matrix, [0, 0, 0, 1]))
|
| 63 |
+
return matrix
|
| 64 |
+
|
| 65 |
+
import einops
|
| 66 |
+
import sys
|
| 67 |
+
|
| 68 |
+
sys.path.insert(0, "./6DoF/") # TODO change it when deploying
|
| 69 |
+
# use the customized diffusers modules
|
| 70 |
+
from diffusers import DDIMScheduler
|
| 71 |
+
from dataset import get_pose
|
| 72 |
+
from CN_encoder import CN_encoder
|
| 73 |
+
from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
|
| 74 |
+
|
| 75 |
+
pretrained_model_name_or_path = "kxic/EscherNet_demo"
|
| 76 |
+
resolution = 256
|
| 77 |
+
h,w = resolution,resolution
|
| 78 |
+
guidance_scale = 3.0
|
| 79 |
+
radius = 2.2
|
| 80 |
+
bg_color = [1., 1., 1., 1.]
|
| 81 |
+
image_transforms = torchvision.transforms.Compose(
|
| 82 |
+
[
|
| 83 |
+
torchvision.transforms.Resize((resolution, resolution)), # 256, 256
|
| 84 |
+
torchvision.transforms.ToTensor(),
|
| 85 |
+
torchvision.transforms.Normalize([0.5], [0.5])
|
| 86 |
+
]
|
| 87 |
+
)
|
| 88 |
+
xyzs_spiral, angles_spiral = get_archimedean_spiral(1.5, 200)
|
| 89 |
+
# only half toop
|
| 90 |
+
xyzs_spiral = xyzs_spiral[:100]
|
| 91 |
+
angles_spiral = angles_spiral[:100]
|
| 92 |
+
|
| 93 |
+
# Init pipeline
|
| 94 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", revision=None)
|
| 95 |
+
image_encoder = CN_encoder.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=None)
|
| 96 |
+
pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
|
| 97 |
+
pretrained_model_name_or_path,
|
| 98 |
+
revision=None,
|
| 99 |
+
scheduler=scheduler,
|
| 100 |
+
image_encoder=None,
|
| 101 |
+
safety_checker=None,
|
| 102 |
+
feature_extractor=None,
|
| 103 |
+
torch_dtype=weight_dtype,
|
| 104 |
+
)
|
| 105 |
+
pipeline.image_encoder = image_encoder.to(weight_dtype)
|
| 106 |
+
pipeline = pipeline.to(device)
|
| 107 |
+
pipeline.set_progress_bar_config(disable=False)
|
| 108 |
+
|
| 109 |
+
pipeline.enable_xformers_memory_efficient_attention()
|
| 110 |
+
# enable vae slicing
|
| 111 |
+
pipeline.enable_vae_slicing()
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@spaces.GPU(duration=120)
|
| 117 |
+
def run_eschernet(tmpdirname, eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
|
| 118 |
+
# set the random seed
|
| 119 |
+
generator = torch.Generator(device=device).manual_seed(sample_seed)
|
| 120 |
+
T_out = nvs_num
|
| 121 |
+
T_in = len(eschernet_input_dict['imgs'])
|
| 122 |
+
####### output pose
|
| 123 |
+
# TODO choose T_out number of poses sequentially from the spiral
|
| 124 |
+
xyzs = xyzs_spiral[::(len(xyzs_spiral) // T_out)]
|
| 125 |
+
angles_out = angles_spiral[::(len(xyzs_spiral) // T_out)]
|
| 126 |
+
|
| 127 |
+
####### input's max radius for translation scaling
|
| 128 |
+
radii = eschernet_input_dict['radii']
|
| 129 |
+
max_t = np.max(radii)
|
| 130 |
+
min_t = np.min(radii)
|
| 131 |
+
|
| 132 |
+
####### input pose
|
| 133 |
+
pose_in = []
|
| 134 |
+
for T_in_index in range(T_in):
|
| 135 |
+
pose = get_pose(np.linalg.inv(eschernet_input_dict['poses'][T_in_index]))
|
| 136 |
+
pose[1:3, :] *= -1 # coordinate system conversion
|
| 137 |
+
pose[3, 3] *= 1. / max_t * radius # scale radius to [1.5, 2.2]
|
| 138 |
+
pose_in.append(torch.from_numpy(pose))
|
| 139 |
+
|
| 140 |
+
####### input image
|
| 141 |
+
img = eschernet_input_dict['imgs'] / 255.
|
| 142 |
+
img[img[:, :, :, -1] == 0.] = bg_color
|
| 143 |
+
# TODO batch image_transforms
|
| 144 |
+
input_image = [image_transforms(Image.fromarray(np.uint8(im[:, :, :3] * 255.)).convert("RGB")) for im in img]
|
| 145 |
+
|
| 146 |
+
####### nvs pose
|
| 147 |
+
pose_out = []
|
| 148 |
+
for T_out_index in range(T_out):
|
| 149 |
+
azimuth, polar = angles_out[T_out_index]
|
| 150 |
+
if CaPE_TYPE == "4DoF":
|
| 151 |
+
pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
|
| 152 |
+
elif CaPE_TYPE == "6DoF":
|
| 153 |
+
pose = look_at(origin=np.array([0, 0, 0]), target=xyzs[T_out_index], up=np.array([0, 0, 1]))
|
| 154 |
+
pose = np.linalg.inv(pose)
|
| 155 |
+
pose[2, :] *= -1
|
| 156 |
+
pose_out.append(torch.from_numpy(get_pose(pose)))
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# [B, T, C, H, W]
|
| 161 |
+
input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
|
| 162 |
+
# [B, T, 4]
|
| 163 |
+
pose_in = np.stack(pose_in)
|
| 164 |
+
pose_out = np.stack(pose_out)
|
| 165 |
+
|
| 166 |
+
if CaPE_TYPE == "6DoF":
|
| 167 |
+
pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
|
| 168 |
+
pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
|
| 169 |
+
pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
|
| 170 |
+
pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)
|
| 171 |
+
|
| 172 |
+
pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
|
| 173 |
+
pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)
|
| 174 |
+
|
| 175 |
+
input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
|
| 176 |
+
assert T_in == input_image.shape[0]
|
| 177 |
+
assert T_in == pose_in.shape[1]
|
| 178 |
+
assert T_out == pose_out.shape[1]
|
| 179 |
+
|
| 180 |
+
# run inference
|
| 181 |
+
if CaPE_TYPE == "6DoF":
|
| 182 |
+
with torch.autocast("cuda"):
|
| 183 |
+
image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
|
| 184 |
+
poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
|
| 185 |
+
height=h, width=w, T_in=T_in, T_out=T_out,
|
| 186 |
+
guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
|
| 187 |
+
output_type="numpy").images
|
| 188 |
+
elif CaPE_TYPE == "4DoF":
|
| 189 |
+
with torch.autocast("cuda"):
|
| 190 |
+
image = pipeline(input_imgs=input_image, prompt_imgs=input_image, poses=[pose_out, pose_in],
|
| 191 |
+
height=h, width=w, T_in=T_in, T_out=T_out,
|
| 192 |
+
guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
|
| 193 |
+
output_type="numpy").images
|
| 194 |
+
|
| 195 |
+
# save output image
|
| 196 |
+
output_dir = os.path.join(tmpdirname, "eschernet")
|
| 197 |
+
if os.path.exists(output_dir):
|
| 198 |
+
shutil.rmtree(output_dir)
|
| 199 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 200 |
+
# save to N imgs
|
| 201 |
+
for i in range(T_out):
|
| 202 |
+
imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
|
| 203 |
+
# make a gif
|
| 204 |
+
frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
|
| 205 |
+
frame_one = frames[0]
|
| 206 |
+
frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames,
|
| 207 |
+
save_all=True, duration=50, loop=1)
|
| 208 |
+
|
| 209 |
+
# get a video
|
| 210 |
+
video_path = os.path.join(output_dir, "output.mp4")
|
| 211 |
+
imageio.mimwrite(video_path, np.stack(frames), fps=10, codec='h264')
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
return image, video_path
|
| 215 |
+
|
| 216 |
+
# TODO mesh it
|
| 217 |
+
@spaces.GPU(duration=120)
|
| 218 |
+
def make3d():
|
| 219 |
+
pass
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
############################ Dust3r as Pose Estimation ############################
|
| 224 |
+
from scipy.spatial.transform import Rotation
|
| 225 |
+
import copy
|
| 226 |
+
|
| 227 |
+
from dust3r.inference import inference
|
| 228 |
+
from dust3r.model import AsymmetricCroCo3DStereo
|
| 229 |
+
from dust3r.image_pairs import make_pairs
|
| 230 |
+
from dust3r.utils.image import load_images, rgb
|
| 231 |
+
from dust3r.utils.device import to_numpy
|
| 232 |
+
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
|
| 233 |
+
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
|
| 234 |
+
|
| 235 |
+
import functools
|
| 236 |
+
import math
|
| 237 |
+
|
| 238 |
+
@spaces.GPU
|
| 239 |
+
def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
|
| 240 |
+
cam_color=None, as_pointcloud=False,
|
| 241 |
+
transparent_cams=False, silent=False, same_focals=False):
|
| 242 |
+
assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
|
| 243 |
+
if not same_focals:
|
| 244 |
+
assert (len(cams2world) == len(focals))
|
| 245 |
+
pts3d = to_numpy(pts3d)
|
| 246 |
+
imgs = to_numpy(imgs)
|
| 247 |
+
focals = to_numpy(focals)
|
| 248 |
+
cams2world = to_numpy(cams2world)
|
| 249 |
+
|
| 250 |
+
scene = trimesh.Scene()
|
| 251 |
+
|
| 252 |
+
# add axes
|
| 253 |
+
scene.add_geometry(trimesh.creation.axis(axis_length=0.5, axis_radius=0.001))
|
| 254 |
+
|
| 255 |
+
# full pointcloud
|
| 256 |
+
if as_pointcloud:
|
| 257 |
+
pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
|
| 258 |
+
col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
|
| 259 |
+
pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
|
| 260 |
+
scene.add_geometry(pct)
|
| 261 |
+
else:
|
| 262 |
+
meshes = []
|
| 263 |
+
for i in range(len(imgs)):
|
| 264 |
+
meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
|
| 265 |
+
mesh = trimesh.Trimesh(**cat_meshes(meshes))
|
| 266 |
+
scene.add_geometry(mesh)
|
| 267 |
+
|
| 268 |
+
# add each camera
|
| 269 |
+
for i, pose_c2w in enumerate(cams2world):
|
| 270 |
+
if isinstance(cam_color, list):
|
| 271 |
+
camera_edge_color = cam_color[i]
|
| 272 |
+
else:
|
| 273 |
+
camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
|
| 274 |
+
if same_focals:
|
| 275 |
+
focal = focals[0]
|
| 276 |
+
else:
|
| 277 |
+
focal = focals[i]
|
| 278 |
+
add_scene_cam(scene, pose_c2w, camera_edge_color,
|
| 279 |
+
None if transparent_cams else imgs[i], focal,
|
| 280 |
+
imsize=imgs[i].shape[1::-1], screen_width=cam_size)
|
| 281 |
+
|
| 282 |
+
rot = np.eye(4)
|
| 283 |
+
rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
|
| 284 |
+
scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
|
| 285 |
+
outfile = os.path.join(outdir, 'scene.glb')
|
| 286 |
+
if not silent:
|
| 287 |
+
print('(exporting 3D scene to', outfile, ')')
|
| 288 |
+
scene.export(file_obj=outfile)
|
| 289 |
+
return outfile
|
| 290 |
+
|
| 291 |
+
@spaces.GPU(duration=120)
|
| 292 |
+
def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
|
| 293 |
+
clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
|
| 294 |
+
"""
|
| 295 |
+
extract 3D_model (glb file) from a reconstructed scene
|
| 296 |
+
"""
|
| 297 |
+
if scene is None:
|
| 298 |
+
return None
|
| 299 |
+
# post processes
|
| 300 |
+
if clean_depth:
|
| 301 |
+
scene = scene.clean_pointcloud()
|
| 302 |
+
if mask_sky:
|
| 303 |
+
scene = scene.mask_sky()
|
| 304 |
+
|
| 305 |
+
# get optimized values from scene
|
| 306 |
+
rgbimg = to_numpy(scene.imgs)
|
| 307 |
+
focals = to_numpy(scene.get_focals().cpu())
|
| 308 |
+
# cams2world = to_numpy(scene.get_im_poses().cpu())
|
| 309 |
+
# TODO use the vis_poses
|
| 310 |
+
cams2world = scene.vis_poses
|
| 311 |
+
|
| 312 |
+
# 3D pointcloud from depthmap, poses and intrinsics
|
| 313 |
+
# pts3d = to_numpy(scene.get_pts3d())
|
| 314 |
+
# TODO use the vis_poses
|
| 315 |
+
pts3d = scene.vis_pts3d
|
| 316 |
+
scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
|
| 317 |
+
msk = to_numpy(scene.get_masks())
|
| 318 |
+
|
| 319 |
+
return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
|
| 320 |
+
transparent_cams=transparent_cams, cam_size=cam_size, silent=silent,
|
| 321 |
+
same_focals=same_focals)
|
| 322 |
+
|
| 323 |
+
@spaces.GPU(duration=120)
|
| 324 |
+
def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr,
|
| 325 |
+
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
| 326 |
+
scenegraph_type, winsize, refid, same_focals):
|
| 327 |
+
"""
|
| 328 |
+
from a list of images, run dust3r inference, global aligner.
|
| 329 |
+
then run get_3D_model_from_scene
|
| 330 |
+
"""
|
| 331 |
+
# remove the directory if it already exists
|
| 332 |
+
if os.path.exists(outdir):
|
| 333 |
+
shutil.rmtree(outdir)
|
| 334 |
+
os.makedirs(outdir, exist_ok=True)
|
| 335 |
+
imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True)
|
| 336 |
+
if len(imgs) == 1:
|
| 337 |
+
imgs = [imgs[0], copy.deepcopy(imgs[0])]
|
| 338 |
+
imgs[1]['idx'] = 1
|
| 339 |
+
if scenegraph_type == "swin":
|
| 340 |
+
scenegraph_type = scenegraph_type + "-" + str(winsize)
|
| 341 |
+
elif scenegraph_type == "oneref":
|
| 342 |
+
scenegraph_type = scenegraph_type + "-" + str(refid)
|
| 343 |
+
|
| 344 |
+
pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
|
| 345 |
+
output = inference(pairs, model, device, batch_size=1, verbose=not silent)
|
| 346 |
+
|
| 347 |
+
mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
|
| 348 |
+
scene = global_aligner(output, device=device, mode=mode, verbose=not silent, same_focals=same_focals)
|
| 349 |
+
lr = 0.01
|
| 350 |
+
|
| 351 |
+
if mode == GlobalAlignerMode.PointCloudOptimizer:
|
| 352 |
+
loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)
|
| 353 |
+
|
| 354 |
+
# outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
|
| 355 |
+
# clean_depth, transparent_cams, cam_size, same_focals=same_focals)
|
| 356 |
+
|
| 357 |
+
# also return rgb, depth and confidence imgs
|
| 358 |
+
# depth is normalized with the max value for all images
|
| 359 |
+
# we apply the jet colormap on the confidence maps
|
| 360 |
+
rgbimg = scene.imgs
|
| 361 |
+
# depths = to_numpy(scene.get_depthmaps())
|
| 362 |
+
# confs = to_numpy([c for c in scene.im_conf])
|
| 363 |
+
# cmap = pl.get_cmap('jet')
|
| 364 |
+
# depths_max = max([d.max() for d in depths])
|
| 365 |
+
# depths = [d / depths_max for d in depths]
|
| 366 |
+
# confs_max = max([d.max() for d in confs])
|
| 367 |
+
# confs = [cmap(d / confs_max) for d in confs]
|
| 368 |
+
|
| 369 |
+
imgs = []
|
| 370 |
+
rgbaimg = []
|
| 371 |
+
for i in range(len(rgbimg)): # when only 1 image, scene.imgs is two
|
| 372 |
+
imgs.append(rgbimg[i])
|
| 373 |
+
# imgs.append(rgb(depths[i]))
|
| 374 |
+
# imgs.append(rgb(confs[i]))
|
| 375 |
+
# imgs.append(imgs_rgba[i])
|
| 376 |
+
if len(imgs_rgba) == 1 and i == 1:
|
| 377 |
+
imgs.append(imgs_rgba[0])
|
| 378 |
+
rgbaimg.append(np.array(imgs_rgba[0]))
|
| 379 |
+
else:
|
| 380 |
+
imgs.append(imgs_rgba[i])
|
| 381 |
+
rgbaimg.append(np.array(imgs_rgba[i]))
|
| 382 |
+
|
| 383 |
+
rgbaimg = np.array(rgbaimg)
|
| 384 |
+
|
| 385 |
+
# for eschernet
|
| 386 |
+
# get optimized values from scene
|
| 387 |
+
rgbimg = to_numpy(scene.imgs)
|
| 388 |
+
focals = to_numpy(scene.get_focals().cpu())
|
| 389 |
+
cams2world = to_numpy(scene.get_im_poses().cpu())
|
| 390 |
+
|
| 391 |
+
# 3D pointcloud from depthmap, poses and intrinsics
|
| 392 |
+
pts3d = to_numpy(scene.get_pts3d())
|
| 393 |
+
scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
|
| 394 |
+
msk = to_numpy(scene.get_masks())
|
| 395 |
+
obj_mask = rgbaimg[..., 3] > 0
|
| 396 |
+
|
| 397 |
+
# TODO set global coordinate system at the center of the scene, z-axis is up
|
| 398 |
+
pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
|
| 399 |
+
pts_obj = np.concatenate([p[m&obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
|
| 400 |
+
centroid = np.mean(pts_obj, axis=0) # obj center
|
| 401 |
+
obj2world = np.eye(4)
|
| 402 |
+
obj2world[:3, 3] = -centroid # T_wc
|
| 403 |
+
|
| 404 |
+
# get z_up vector
|
| 405 |
+
# TODO fit a plane and get the normal vector
|
| 406 |
+
pcd = o3d.geometry.PointCloud()
|
| 407 |
+
pcd.points = o3d.utility.Vector3dVector(pts)
|
| 408 |
+
plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
|
| 409 |
+
# get the normalised normal vector dim = 3
|
| 410 |
+
normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
|
| 411 |
+
# the normal direction should be pointing up
|
| 412 |
+
if normal[1] < 0:
|
| 413 |
+
normal = -normal
|
| 414 |
+
# print("normal", normal)
|
| 415 |
+
|
| 416 |
+
# # TODO z-up 180
|
| 417 |
+
# z_up = np.array([[1,0,0,0],
|
| 418 |
+
# [0,-1,0,0],
|
| 419 |
+
# [0,0,-1,0],
|
| 420 |
+
# [0,0,0,1]])
|
| 421 |
+
# obj2world = z_up @ obj2world
|
| 422 |
+
|
| 423 |
+
# # avg the y
|
| 424 |
+
# z_up_avg = cams2world[:,:3,3].sum(0) / np.linalg.norm(cams2world[:,:3,3].sum(0), axis=-1) # average direction in cam coordinate
|
| 425 |
+
# # import pdb; pdb.set_trace()
|
| 426 |
+
# rot_axis = np.cross(np.array([0, 0, 1]), z_up_avg)
|
| 427 |
+
# rot_angle = np.arccos(np.dot(np.array([0, 0, 1]), z_up_avg) / (np.linalg.norm(z_up_avg) + 1e-6))
|
| 428 |
+
# rot = Rotation.from_rotvec(rot_angle * rot_axis)
|
| 429 |
+
# z_up = np.eye(4)
|
| 430 |
+
# z_up[:3, :3] = rot.as_matrix()
|
| 431 |
+
|
| 432 |
+
# get the rotation matrix from normal to z-axis
|
| 433 |
+
z_axis = np.array([0, 0, 1])
|
| 434 |
+
rot_axis = np.cross(normal, z_axis)
|
| 435 |
+
rot_angle = np.arccos(np.dot(normal, z_axis) / (np.linalg.norm(normal) + 1e-6))
|
| 436 |
+
rot = Rotation.from_rotvec(rot_angle * rot_axis)
|
| 437 |
+
z_up = np.eye(4)
|
| 438 |
+
z_up[:3, :3] = rot.as_matrix()
|
| 439 |
+
obj2world = z_up @ obj2world
|
| 440 |
+
# flip 180
|
| 441 |
+
flip_rot = np.array([[1, 0, 0, 0],
|
| 442 |
+
[0, -1, 0, 0],
|
| 443 |
+
[0, 0, -1, 0],
|
| 444 |
+
[0, 0, 0, 1]])
|
| 445 |
+
obj2world = flip_rot @ obj2world
|
| 446 |
+
|
| 447 |
+
# get new cams2obj
|
| 448 |
+
cams2obj = []
|
| 449 |
+
for i, cam2world in enumerate(cams2world):
|
| 450 |
+
cams2obj.append(obj2world @ cam2world)
|
| 451 |
+
# TODO transform pts3d to the new coordinate system
|
| 452 |
+
for i, pts in enumerate(pts3d):
|
| 453 |
+
pts3d[i] = (obj2world @ np.concatenate([pts, np.ones_like(pts)[..., :1]], axis=-1).transpose(2, 0, 1).reshape(4,
|
| 454 |
+
-1)) \
|
| 455 |
+
.reshape(4, pts.shape[0], pts.shape[1]).transpose(1, 2, 0)[..., :3]
|
| 456 |
+
cams2world = np.array(cams2obj)
|
| 457 |
+
# TODO rewrite hack
|
| 458 |
+
scene.vis_poses = cams2world.copy()
|
| 459 |
+
scene.vis_pts3d = pts3d.copy()
|
| 460 |
+
|
| 461 |
+
# TODO save cams2world and rgbimg to each file, file name "000.npy", "001.npy", ... and "000.png", "001.png", ...
|
| 462 |
+
for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
|
| 463 |
+
np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
|
| 464 |
+
pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
|
| 465 |
+
pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
|
| 466 |
+
# np.save(os.path.join(outdir, f"{i:03d}_focal.npy"), to_numpy(focal))
|
| 467 |
+
# save the min/max radius of camera
|
| 468 |
+
radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3])
|
| 469 |
+
np.save(os.path.join(outdir, "radii.npy"), radii)
|
| 470 |
+
|
| 471 |
+
eschernet_input = {"poses": cams2world,
|
| 472 |
+
"radii": radii,
|
| 473 |
+
"imgs": rgbaimg}
|
| 474 |
+
|
| 475 |
+
outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
|
| 476 |
+
clean_depth, transparent_cams, cam_size, same_focals=same_focals)
|
| 477 |
+
|
| 478 |
+
return scene, outfile, imgs, eschernet_input
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
|
| 482 |
+
num_files = len(inputfiles) if inputfiles is not None else 1
|
| 483 |
+
max_winsize = max(1, math.ceil((num_files - 1) / 2))
|
| 484 |
+
if scenegraph_type == "swin":
|
| 485 |
+
winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
|
| 486 |
+
minimum=1, maximum=max_winsize, step=1, visible=True)
|
| 487 |
+
refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
|
| 488 |
+
maximum=num_files - 1, step=1, visible=False)
|
| 489 |
+
elif scenegraph_type == "oneref":
|
| 490 |
+
winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
|
| 491 |
+
minimum=1, maximum=max_winsize, step=1, visible=False)
|
| 492 |
+
refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
|
| 493 |
+
maximum=num_files - 1, step=1, visible=True)
|
| 494 |
+
else:
|
| 495 |
+
winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
|
| 496 |
+
minimum=1, maximum=max_winsize, step=1, visible=False)
|
| 497 |
+
refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
|
| 498 |
+
maximum=num_files - 1, step=1, visible=False)
|
| 499 |
+
return winsize, refid
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def get_examples(path):
|
| 503 |
+
objs = []
|
| 504 |
+
for obj_name in sorted(os.listdir(path)):
|
| 505 |
+
img_files = []
|
| 506 |
+
for img_file in sorted(os.listdir(os.path.join(path, obj_name))):
|
| 507 |
+
img_files.append(os.path.join(path, obj_name, img_file))
|
| 508 |
+
objs.append([img_files])
|
| 509 |
+
print("objs = ", objs)
|
| 510 |
+
return objs
|
| 511 |
+
|
| 512 |
+
def preview_input(inputfiles):
|
| 513 |
+
if inputfiles is None:
|
| 514 |
+
return None
|
| 515 |
+
imgs = []
|
| 516 |
+
for img_file in inputfiles:
|
| 517 |
+
img = pl.imread(img_file)
|
| 518 |
+
imgs.append(img)
|
| 519 |
+
return imgs
|
| 520 |
+
|
| 521 |
+
def main():
|
| 522 |
+
# dustr init
|
| 523 |
+
silent = False
|
| 524 |
+
image_size = 224
|
| 525 |
+
weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
|
| 526 |
+
model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
|
| 527 |
+
# dust3r will write the 3D model inside tmpdirname
|
| 528 |
+
# with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
|
| 529 |
+
tmpdirname = os.path.join('logs/user_object')
|
| 530 |
+
# remove the directory if it already exists
|
| 531 |
+
if os.path.exists(tmpdirname):
|
| 532 |
+
shutil.rmtree(tmpdirname)
|
| 533 |
+
os.makedirs(tmpdirname, exist_ok=True)
|
| 534 |
+
if not silent:
|
| 535 |
+
print('Outputing stuff in', tmpdirname)
|
| 536 |
+
|
| 537 |
+
recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
|
| 538 |
+
model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
|
| 539 |
+
|
| 540 |
+
generate_mvs = functools.partial(run_eschernet, tmpdirname)
|
| 541 |
+
|
| 542 |
+
_HEADER_ = '''
|
| 543 |
+
<h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2>
|
| 544 |
+
<b>EscherNet</b> is a multiview diffusion model for scalable generative any-to-any number/pose novel view synthesis.
|
| 545 |
+
|
| 546 |
+
Image views are treated as tokens and the camera pose is encoded by <b>CaPE (Camera Positional Encoding)</b>.
|
| 547 |
+
|
| 548 |
+
<a href='https://kxhit.github.io/EscherNet' target='_blank'>Project</a> <b>|</b>
|
| 549 |
+
<a href='https://github.com/kxhit/EscherNet' target='_blank'>GitHub</a> <b>|</b>
|
| 550 |
+
<a href='https://arxiv.org/abs/2402.03908' target='_blank'>ArXiv</a>
|
| 551 |
+
|
| 552 |
+
<h4><b>Tips:</b></h4>
|
| 553 |
+
|
| 554 |
+
- Our model can take <b>any number input images</b>. The more images you provide, the better the results.
|
| 555 |
+
|
| 556 |
+
- Our model can generate <b>any number and any pose</b> novel views. You can specify the number of views you want to generate. In this demo, we set novel views on an <b>archemedian spiral</b> for simplicity.
|
| 557 |
+
|
| 558 |
+
- The pose estimation is done using <a href='https://github.com/naver/dust3r' target='_blank'>DUSt3R</a>. You can also provide your own poses or get pose via any SLAM system.
|
| 559 |
+
|
| 560 |
+
- The current checkpoint supports 6DoF camera pose and is trained on 30k 3D <a href='https://objaverse.allenai.org/' target='_blank'>Objaverse</a> objects for demo. Scaling is on the roadmap!
|
| 561 |
+
|
| 562 |
+
'''
|
| 563 |
+
|
| 564 |
+
_CITE_ = r"""
|
| 565 |
+
📝 <b>Citation</b>:
|
| 566 |
+
```bibtex
|
| 567 |
+
@article{kong2024eschernet,
|
| 568 |
+
title={EscherNet: A Generative Model for Scalable View Synthesis},
|
| 569 |
+
author={Kong, Xin and Liu, Shikun and Lyu, Xiaoyang and Taher, Marwan and Qi, Xiaojuan and Davison, Andrew J},
|
| 570 |
+
journal={arXiv preprint arXiv:2402.03908},
|
| 571 |
+
year={2024}
|
| 572 |
+
}
|
| 573 |
+
```
|
| 574 |
+
"""
|
| 575 |
+
|
| 576 |
+
with gr.Blocks() as demo:
|
| 577 |
+
gr.Markdown(_HEADER_)
|
| 578 |
+
mv_images = gr.State()
|
| 579 |
+
scene = gr.State(None)
|
| 580 |
+
eschernet_input = gr.State(None)
|
| 581 |
+
with gr.Row(variant="panel"):
|
| 582 |
+
# left column
|
| 583 |
+
with gr.Column():
|
| 584 |
+
with gr.Row():
|
| 585 |
+
input_image = gr.File(file_count="multiple")
|
| 586 |
+
# with gr.Row():
|
| 587 |
+
# # set the size of the window
|
| 588 |
+
# preview_image = gr.Gallery(label='Input Views', rows=1,
|
| 589 |
+
with gr.Row():
|
| 590 |
+
run_dust3r = gr.Button("Get Pose!", elem_id="dust3r")
|
| 591 |
+
with gr.Row():
|
| 592 |
+
processed_image = gr.Gallery(label='Input Views', columns=2, height="100%")
|
| 593 |
+
with gr.Row(variant="panel"):
|
| 594 |
+
# input examples under "examples" folder
|
| 595 |
+
gr.Examples(
|
| 596 |
+
examples=get_examples('examples'),
|
| 597 |
+
# examples=[
|
| 598 |
+
# [['examples/controller/frame000077.jpg', 'examples/controller/frame000032.jpg', 'examples/controller/frame000172.jpg']],
|
| 599 |
+
# [['examples/hairdryer/frame000081.jpg', 'examples/hairdryer/frame000162.jpg', 'examples/hairdryer/frame000003.jpg']],
|
| 600 |
+
# ],
|
| 601 |
+
inputs=[input_image],
|
| 602 |
+
label="Examples (click one set of images to start!)",
|
| 603 |
+
examples_per_page=20
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
# right column
|
| 611 |
+
with gr.Column():
|
| 612 |
+
|
| 613 |
+
with gr.Row():
|
| 614 |
+
outmodel = gr.Model3D()
|
| 615 |
+
|
| 616 |
+
with gr.Row():
|
| 617 |
+
gr.Markdown('''
|
| 618 |
+
<h4><b>Check if the pose and segmentation looks correct. If not, remove the incorrect images and try again.</b></h4>
|
| 619 |
+
''')
|
| 620 |
+
|
| 621 |
+
with gr.Row():
|
| 622 |
+
with gr.Group():
|
| 623 |
+
do_remove_background = gr.Checkbox(
|
| 624 |
+
label="Remove Background", value=True
|
| 625 |
+
)
|
| 626 |
+
sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
|
| 627 |
+
|
| 628 |
+
sample_steps = gr.Slider(
|
| 629 |
+
label="Sample Steps",
|
| 630 |
+
minimum=30,
|
| 631 |
+
maximum=75,
|
| 632 |
+
value=50,
|
| 633 |
+
step=5,
|
| 634 |
+
visible=False
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
nvs_num = gr.Slider(
|
| 638 |
+
label="Number of Novel Views",
|
| 639 |
+
minimum=5,
|
| 640 |
+
maximum=100,
|
| 641 |
+
value=30,
|
| 642 |
+
step=1
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
nvs_mode = gr.Dropdown(["archimedes circle"], # "fixed 4 views", "fixed 8 views"
|
| 646 |
+
value="archimedes circle", label="Novel Views Pose Chosen", visible=True)
|
| 647 |
+
|
| 648 |
+
with gr.Row():
|
| 649 |
+
gr.Markdown('''
|
| 650 |
+
<h4><b>Choose your desired novel view poses number and generate! The more output images the longer it takes.</b></h4>
|
| 651 |
+
''')
|
| 652 |
+
|
| 653 |
+
with gr.Row():
|
| 654 |
+
submit = gr.Button("Submit", elem_id="eschernet", variant="primary")
|
| 655 |
+
|
| 656 |
+
with gr.Row():
|
| 657 |
+
# mv_show_images = gr.Image(
|
| 658 |
+
# label="Generated Multi-views",
|
| 659 |
+
# type="pil",
|
| 660 |
+
# width=379,
|
| 661 |
+
# interactive=False
|
| 662 |
+
# )
|
| 663 |
+
with gr.Column():
|
| 664 |
+
output_video = gr.Video(
|
| 665 |
+
label="video", format="mp4",
|
| 666 |
+
width=379,
|
| 667 |
+
autoplay=True,
|
| 668 |
+
interactive=False
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
# with gr.Row():
|
| 672 |
+
# with gr.Tab("OBJ"):
|
| 673 |
+
# output_model_obj = gr.Model3D(
|
| 674 |
+
# label="Output Model (OBJ Format)",
|
| 675 |
+
# #width=768,
|
| 676 |
+
# interactive=False,
|
| 677 |
+
# )
|
| 678 |
+
# gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
|
| 679 |
+
# with gr.Tab("GLB"):
|
| 680 |
+
# output_model_glb = gr.Model3D(
|
| 681 |
+
# label="Output Model (GLB Format)",
|
| 682 |
+
# #width=768,
|
| 683 |
+
# interactive=False,
|
| 684 |
+
# )
|
| 685 |
+
# gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
|
| 686 |
+
|
| 687 |
+
with gr.Row():
|
| 688 |
+
gr.Markdown('''The novel views are generated on an archimedean spiral. You can download the video''')
|
| 689 |
+
|
| 690 |
+
gr.Markdown(_CITE_)
|
| 691 |
+
|
| 692 |
+
# set dust3r parameter invisible to be clean
|
| 693 |
+
with gr.Column():
|
| 694 |
+
with gr.Row():
|
| 695 |
+
schedule = gr.Dropdown(["linear", "cosine"],
|
| 696 |
+
value='linear', label="schedule", info="For global alignment!", visible=False)
|
| 697 |
+
niter = gr.Number(value=300, precision=0, minimum=0, maximum=5000,
|
| 698 |
+
label="num_iterations", info="For global alignment!", visible=False)
|
| 699 |
+
scenegraph_type = gr.Dropdown(["complete", "swin", "oneref"],
|
| 700 |
+
value='complete', label="Scenegraph",
|
| 701 |
+
info="Define how to make pairs",
|
| 702 |
+
interactive=True, visible=False)
|
| 703 |
+
same_focals = gr.Checkbox(value=True, label="Focal", info="Use the same focal for all cameras", visible=False)
|
| 704 |
+
winsize = gr.Slider(label="Scene Graph: Window Size", value=1,
|
| 705 |
+
minimum=1, maximum=1, step=1, visible=False)
|
| 706 |
+
refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
|
| 707 |
+
|
| 708 |
+
with gr.Row():
|
| 709 |
+
# adjust the confidence threshold
|
| 710 |
+
min_conf_thr = gr.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
|
| 711 |
+
# adjust the camera size in the output pointcloud
|
| 712 |
+
cam_size = gr.Slider(label="cam_size", value=0.05, minimum=0.01, maximum=0.5, step=0.001, visible=False)
|
| 713 |
+
with gr.Row():
|
| 714 |
+
as_pointcloud = gr.Checkbox(value=False, label="As pointcloud", visible=False)
|
| 715 |
+
# two post process implemented
|
| 716 |
+
mask_sky = gr.Checkbox(value=False, label="Mask sky", visible=False)
|
| 717 |
+
clean_depth = gr.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
|
| 718 |
+
transparent_cams = gr.Checkbox(value=False, label="Transparent cameras", visible=False)
|
| 719 |
+
|
| 720 |
+
# events
|
| 721 |
+
# scenegraph_type.change(set_scenegraph_options,
|
| 722 |
+
# inputs=[input_image, winsize, refid, scenegraph_type],
|
| 723 |
+
# outputs=[winsize, refid])
|
| 724 |
+
input_image.change(set_scenegraph_options,
|
| 725 |
+
inputs=[input_image, winsize, refid, scenegraph_type],
|
| 726 |
+
outputs=[winsize, refid])
|
| 727 |
+
# min_conf_thr.release(fn=model_from_scene_fun,
|
| 728 |
+
# inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
|
| 729 |
+
# clean_depth, transparent_cams, cam_size, same_focals],
|
| 730 |
+
# outputs=outmodel)
|
| 731 |
+
# cam_size.change(fn=model_from_scene_fun,
|
| 732 |
+
# inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
|
| 733 |
+
# clean_depth, transparent_cams, cam_size, same_focals],
|
| 734 |
+
# outputs=outmodel)
|
| 735 |
+
# as_pointcloud.change(fn=model_from_scene_fun,
|
| 736 |
+
# inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
|
| 737 |
+
# clean_depth, transparent_cams, cam_size, same_focals],
|
| 738 |
+
# outputs=outmodel)
|
| 739 |
+
# mask_sky.change(fn=model_from_scene_fun,
|
| 740 |
+
# inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
|
| 741 |
+
# clean_depth, transparent_cams, cam_size, same_focals],
|
| 742 |
+
# outputs=outmodel)
|
| 743 |
+
# clean_depth.change(fn=model_from_scene_fun,
|
| 744 |
+
# inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
|
| 745 |
+
# clean_depth, transparent_cams, cam_size, same_focals],
|
| 746 |
+
# outputs=outmodel)
|
| 747 |
+
# transparent_cams.change(model_from_scene_fun,
|
| 748 |
+
# inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
|
| 749 |
+
# clean_depth, transparent_cams, cam_size, same_focals],
|
| 750 |
+
# outputs=outmodel)
|
| 751 |
+
run_dust3r.click(fn=recon_fun,
|
| 752 |
+
inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
|
| 753 |
+
mask_sky, clean_depth, transparent_cams, cam_size,
|
| 754 |
+
scenegraph_type, winsize, refid, same_focals],
|
| 755 |
+
outputs=[scene, outmodel, processed_image, eschernet_input])
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
# events
|
| 759 |
+
# preview images on input change
|
| 760 |
+
input_image.change(fn=preview_input,
|
| 761 |
+
inputs=[input_image],
|
| 762 |
+
outputs=[processed_image])
|
| 763 |
+
|
| 764 |
+
submit.click(fn=generate_mvs,
|
| 765 |
+
inputs=[eschernet_input, sample_steps, sample_seed,
|
| 766 |
+
nvs_num, nvs_mode],
|
| 767 |
+
outputs=[mv_images, output_video],
|
| 768 |
+
)#.success(
|
| 769 |
+
# # fn=make3d,
|
| 770 |
+
# # inputs=[mv_images],
|
| 771 |
+
# # outputs=[output_video, output_model_obj, output_model_glb]
|
| 772 |
+
# # )
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
demo.queue(max_size=10)
|
| 777 |
+
demo.launch(share=True, server_name="0.0.0.0", server_port=None)
|
| 778 |
+
|
| 779 |
+
if __name__ == '__main__':
|
| 780 |
+
main()
|