Spaces commit by jiaweir
Commit: 5b9bbe2 ("optimize")
Parent(s): fc94f83

Files changed:
- app.py (+184 -19)
- configs/4d_demo.yaml (+1 -1)
- lgm/infer_demo.py (+197 -0, new file)
- main_4d_demo.py (+616 -0, new file)

app.py
CHANGED

[Left (old) side of the diff: context lines plus the lines removed by this commit, prefixed with "-". Several removed lines are truncated to a bare "-" or cut off mid-line in the page source and are kept that way here.]

@@ -7,6 +7,26 @@ import numpy
 import hashlib
 import shlex

 import spaces


@@ -27,45 +47,179 @@ function refresh() {
 }
 """

 # check if there is a picture uploaded or selected
 def check_img_input(control_image):
     if control_image is None:
         raise gr.Error("Please select or upload an input image")

 # check if there is a picture uploaded or selected
-def
     img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
     if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')):
         raise gr.Error("Please generate a video first")


 @spaces.GPU()
-def
     if not os.path.exists('tmp_data'):
         os.makedirs('tmp_data')
     img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
-    if
-
-

-
-
-
-
-

     # stage 1
-    subprocess.run(f'export MKL_THREADING_LAYER=GNU;export MKL_SERVICE_FORCE_INTEL=1;python scripts/gen_vid.py --path tmp_data/{img_hash}_rgba.png --seed {seed_slider} --bg white', shell=True)
-
     # return [os.path.join('logs', 'tmp_rgba_model.ply')]
     return os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')

-
 def optimize_stage_2(image_block: Image.Image, seed_slider: int):
     img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()

     # stage 2
-    subprocess.run(f'python main_4d.py --config {os.path.join("configs", "4d_demo.yaml")} input={os.path.join("tmp_data", f"{img_hash}_rgba.png")}', shell=True)
     # os.rename(os.path.join('logs', f'{img_hash}_rgba_frames'), os.path.join('logs', f'{img_hash}_{seed_slider:03d}_rgba_frames'))
     image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
     # return 'vis_data/tmp_rgba.mp4', [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith('.ply')]

@@ -83,7 +237,7 @@ if __name__ == "__main__":
     </div>
     We present DreamGausssion4D, an efficient 4D generation framework that builds on Gaussian Splatting.
     '''
-    _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above),

     # load images in 'data' folder as examples
     example_folder = os.path.join(os.path.dirname(__file__), 'data')

@@ -104,7 +258,8 @@ if __name__ == "__main__":
     image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')

     # elevation_slider = gr.Slider(-90, 90, value=0, step=1, label='Estimated elevation angle')
-    seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed')
     gr.Markdown(
         "random seed for video generation.")

@@ -120,20 +275,30 @@ if __name__ == "__main__":
         examples_per_page=40
     )
     img_run_btn = gr.Button("Generate Video")
     fourd_run_btn = gr.Button("Generate 4D")
     img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)

     with gr.Column(scale=5):
-
         obj4d = Model4DGS(label="4D Model", height=500, fps=14)

-
         inputs=[image_block,
                 preprocess_chk,
                 seed_slider],
         outputs=[
             obj3d])
-    fourd_run_btn.click(

     # demo.queue().launch(share=True)
     demo.queue(max_size=10)  # <-- Sets up a queue with default parameters

[Right (new) side of the same hunks: context lines plus the lines added by this commit, prefixed with "+".]

@@ -7,6 +7,26 @@ import numpy
 import hashlib
 import shlex

+import rembg
+import glob
+import cv2
+import numpy as np
+from diffusers import StableVideoDiffusionPipeline
+from scripts.gen_vid import *
+
+import sys
+sys.path.append('lgm')
+from safetensors.torch import load_file
+from kiui.cam import orbit_camera
+from core.options import config_defaults, Options
+from core.models import LGM
+from mvdream.pipeline_mvdream import MVDreamPipeline
+from infer_demo import process as process_lgm
+
+from main_4d_demo import process as process_dg4d
+
+
+
 import spaces


@@ -27,45 +47,179 @@ function refresh() {
 }
 """

+
+device = torch.device('cuda')
+# device = torch.device('cpu')
+
+session = rembg.new_session(model_name='u2net')
+
+pipe = StableVideoDiffusionPipeline.from_pretrained(
+    "stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
+)
+pipe.to(device)
+
+opt = config_defaults['big']
+opt.resume = ckpt_path
+# model
+model = LGM(opt)
+
+# resume pretrained checkpoint
+if opt.resume is not None:
+    if opt.resume.endswith('safetensors'):
+        ckpt = load_file(opt.resume, device='cpu')
+    else:
+        ckpt = torch.load(opt.resume, map_location='cpu')
+    model.load_state_dict(ckpt, strict=False)
+    print(f'[INFO] Loaded checkpoint from {opt.resume}')
+else:
+    print(f'[WARN] model randomly initialized, are you sure?')
+
+# device
+model = model.half().to(device)
+model.eval()
+rays_embeddings = model.prepare_default_rays(device)
+
+# load image dream
+pipe_mvdream = MVDreamPipeline.from_pretrained(
+    "ashawkey/imagedream-ipmv-diffusers",  # remote weights
+    torch_dtype=torch.float16,
+    trust_remote_code=True,
+    # local_files_only=True,
+)
+pipe_mvdream = pipe_mvdream.to(device)
+
+from guidance.zero123_utils import Zero123
+guidance_zero123 = Zero123(device, model_key='ashawkey/stable-zero123-diffusers')
+
+def preprocess(path, recenter=True, size=256, border_ratio=0.2):
+    files = [path]
+    out_dir = os.path.dirname(path)
+
+    for file in files:
+
+        out_base = os.path.basename(file).split('.')[0]
+        out_rgba = os.path.join(out_dir, out_base + '_rgba.png')
+
+        # load image
+        print(f'[INFO] loading image {file}...')
+        image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
+
+        # carve background
+        print(f'[INFO] background removal...')
+        carved_image = rembg.remove(image, session=session)  # [H, W, 4]
+        mask = carved_image[..., -1] > 0
+
+        # recenter
+        if recenter:
+            print(f'[INFO] recenter...')
+            final_rgba = np.zeros((size, size, 4), dtype=np.uint8)
+
+            coords = np.nonzero(mask)
+            x_min, x_max = coords[0].min(), coords[0].max()
+            y_min, y_max = coords[1].min(), coords[1].max()
+            h = x_max - x_min
+            w = y_max - y_min
+            desired_size = int(size * (1 - border_ratio))
+            scale = desired_size / max(h, w)
+            h2 = int(h * scale)
+            w2 = int(w * scale)
+            x2_min = (size - h2) // 2
+            x2_max = x2_min + h2
+            y2_min = (size - w2) // 2
+            y2_max = y2_min + w2
+            final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
+
+        else:
+            final_rgba = carved_image
+
+        # write image
+        cv2.imwrite(out_rgba, final_rgba)
+
+def gen_vid(input_path, seed, bg='white'):
+    name = input_path.split('/')[-1].split('.')[0]
+    input_dir = os.path.dirname(input_path)
+    height, width = 512, 512
+
+    image = load_image(input_path, width, height, bg)
+
+    generator = torch.manual_seed(seed)
+    # frames = pipe(image, height, width, decode_chunk_size=2, generator=generator).frames[0]
+    frames = pipe(image, height, width, generator=generator).frames[0]
+
+    imageio.mimwrite(f"{input_dir}/{name}_generated.mp4", frames, fps=7)
+    os.makedirs(f"{input_dir}/{name}_frames", exist_ok=True)
+    for idx, img in enumerate(frames):
+        img.save(f"{input_dir}/{name}_frames/{idx:03}.png")
+
 # check if there is a picture uploaded or selected
 def check_img_input(control_image):
     if control_image is None:
         raise gr.Error("Please select or upload an input image")

 # check if there is a picture uploaded or selected
+def check_video_3d_input(image_block: Image.Image):
     img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
     if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')):
         raise gr.Error("Please generate a video first")
+    if not os.path.exists(os.path.join('vis_data', f'{img_hash}_rgba_static.mp4')):
+        raise gr.Error("Please generate a 3D first")
+


 @spaces.GPU()
+def optimize_stage_0(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
     if not os.path.exists('tmp_data'):
         os.makedirs('tmp_data')
     img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
+    if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba.png')):
+        if preprocess_chk:
+            # save image to a designated path
+            image_block.save(os.path.join('tmp_data', f'{img_hash}.png'))

+            # preprocess image
+            # print(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}')
+            # subprocess.run(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}', shell=True)
+            preprocess(os.path.join("tmp_data", f"{img_hash}.png"))
+        else:
+            image_block.save(os.path.join('tmp_data', f'{img_hash}_rgba.png'))

     # stage 1
+    # subprocess.run(f'export MKL_THREADING_LAYER=GNU;export MKL_SERVICE_FORCE_INTEL=1;python scripts/gen_vid.py --path tmp_data/{img_hash}_rgba.png --seed {seed_slider} --bg white', shell=True)
+    gen_vid(f'tmp_data/{img_hash}_rgba.png', seed_slider)
     # return [os.path.join('logs', 'tmp_rgba_model.ply')]
     return os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')

+
+@spaces.GPU()
+def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
+    if not os.path.exists('tmp_data'):
+        os.makedirs('tmp_data')
+    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
+    if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba.png')):
+        if preprocess_chk:
+            # save image to a designated path
+            image_block.save(os.path.join('tmp_data', f'{img_hash}.png'))
+
+            # preprocess image
+            # print(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}')
+            # subprocess.run(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}', shell=True)
+            preprocess(os.path.join("tmp_data", f"{img_hash}.png"))
+        else:
+            image_block.save(os.path.join('tmp_data', f'{img_hash}_rgba.png'))
+
+    # stage 1
+    # subprocess.run(f'python lgm/infer.py big --resume {ckpt_path} --test_path tmp_data/{img_hash}_rgba.png', shell=True)
+    process_lgm(opt, f'tmp_data/{img_hash}_rgba.png', pipe_mvdream, model, rays_embeddings)
+    # return [os.path.join('logs', 'tmp_rgba_model.ply')]
+    return os.path.join('vis_data', f'{img_hash}_rgba_static.mp4')
+
+@spaces.GPU(duration=120)
 def optimize_stage_2(image_block: Image.Image, seed_slider: int):
     img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()

     # stage 2
+    # subprocess.run(f'python main_4d.py --config {os.path.join("configs", "4d_demo.yaml")} input={os.path.join("tmp_data", f"{img_hash}_rgba.png")}', shell=True)
+    process_dg4d(os.path.join("configs", "4d_demo.yaml"), os.path.join("tmp_data", f"{img_hash}_rgba.png"), guidance_zero123)
     # os.rename(os.path.join('logs', f'{img_hash}_rgba_frames'), os.path.join('logs', f'{img_hash}_{seed_slider:03d}_rgba_frames'))
     image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
     # return 'vis_data/tmp_rgba.mp4', [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith('.ply')]

@@ -83,7 +237,7 @@ if __name__ == "__main__":
     </div>
     We present DreamGausssion4D, an efficient 4D generation framework that builds on Gaussian Splatting.
     '''
+    _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), click **Generate Video** and **Generate 3D**. Finally, click **Generate 4D**."

     # load images in 'data' folder as examples
     example_folder = os.path.join(os.path.dirname(__file__), 'data')

@@ -104,7 +258,8 @@ if __name__ == "__main__":
     image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')

     # elevation_slider = gr.Slider(-90, 90, value=0, step=1, label='Estimated elevation angle')
+    seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed (Video)')
+    seed_slider2 = gr.Slider(0, 100000, value=0, step=1, label='Random Seed (3D)')
     gr.Markdown(
         "random seed for video generation.")

@@ -120,20 +275,30 @@ if __name__ == "__main__":
         examples_per_page=40
     )
     img_run_btn = gr.Button("Generate Video")
+    threed_run_btn = gr.Button("Generate 3D")
     fourd_run_btn = gr.Button("Generate 4D")
     img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)

     with gr.Column(scale=5):
+        dirving_video = gr.Video(label="video", height=290)
+        obj3d = gr.Video(label="3D Model", height=290)
         obj4d = Model4DGS(label="4D Model", height=500, fps=14)

+
+    img_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(optimize_stage_0,
         inputs=[image_block,
                 preprocess_chk,
                 seed_slider],
+        outputs=[
+            dirving_video])
+
+    threed_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(optimize_stage_1,
+        inputs=[image_block,
+                preprocess_chk,
+                seed_slider2],
         outputs=[
             obj3d])
+    fourd_run_btn.click(check_video_3d_input, inputs=[image_block], queue=False).success(optimize_stage_2, inputs=[image_block, seed_slider], outputs=[obj4d])

     # demo.queue().launch(share=True)
     demo.queue(max_size=10)  # <-- Sets up a queue with default parameters
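
The button wiring above relies on Gradio's event chaining: a lightweight validation callback runs first, and the heavy stage function only runs if that callback did not raise. A minimal, self-contained sketch of this pattern follows; the component and function names are placeholders rather than the Space's own code, and it assumes a recent Gradio version whose event listeners return a chainable dependency with a .success() method:

import gradio as gr

def check_input(img):
    # validation step; raising gr.Error aborts the chain and shows a message
    if img is None:
        raise gr.Error("Please upload an image first")

def heavy_stage(img):
    # stands in for a GPU stage such as video / 3D / 4D generation
    return img

with gr.Blocks() as demo:
    image = gr.Image(type="pil")
    result = gr.Image()
    run_btn = gr.Button("Run")
    # .click() returns an event object; .success() registers a follow-up
    # callback that fires only if the first callback finished without error
    run_btn.click(check_input, inputs=[image], queue=False).success(
        heavy_stage, inputs=[image], outputs=[result]
    )

demo.queue(max_size=10)
# demo.launch()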

configs/4d_demo.yaml
CHANGED

@@ -30,7 +30,7 @@ lambda_svd: 0
 # training batch size per iter
 batch_size: 7
 # training iterations for stage 1
-iters:
+iters: 300
 # training iterations for stage 2
 iters_refine: 50
 # training camera radius

(The previous iters value is truncated in the page source.)
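For context, main_4d_demo.py loads this file with OmegaConf and merges command-line style overrides on top of the YAML defaults, so values such as iters can be changed without editing the file. A minimal sketch of that pattern; the override values below are only illustrative:

from omegaconf import OmegaConf

# the YAML file provides the defaults; dotlist entries override individual keys
opt = OmegaConf.merge(
    OmegaConf.load("configs/4d_demo.yaml"),
    OmegaConf.from_dotlist(["iters=250", "input=tmp_data/example_rgba.png"]),
)
print(opt.iters)  # 250
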
lgm/infer_demo.py
ADDED
@@ -0,0 +1,197 @@

import os
import tyro
import glob
import imageio
import numpy as np
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from safetensors.torch import load_file

import kiui
from kiui.op import recenter
from kiui.cam import orbit_camera

from core.options import AllConfigs, Options
from core.models import LGM
from mvdream.pipeline_mvdream import MVDreamPipeline
import cv2

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# opt = tyro.cli(AllConfigs)

# # model
# model = LGM(opt)

# # resume pretrained checkpoint
# if opt.resume is not None:
#     if opt.resume.endswith('safetensors'):
#         ckpt = load_file(opt.resume, device='cpu')
#     else:
#         ckpt = torch.load(opt.resume, map_location='cpu')
#     model.load_state_dict(ckpt, strict=False)
#     print(f'[INFO] Loaded checkpoint from {opt.resume}')
# else:
#     print(f'[WARN] model randomly initialized, are you sure?')

# # device
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = model.half().to(device)
# model.eval()



# process function
def process(opt: Options, path, pipe, model, rays_embeddings):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
    proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
    proj_matrix[0, 0] = 1 / tan_half_fov
    proj_matrix[1, 1] = 1 / tan_half_fov
    proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
    proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
    proj_matrix[2, 3] = 1


    name = os.path.splitext(os.path.basename(path))[0]
    print(f'[INFO] Processing {path} --> {name}')
    os.makedirs('vis_data', exist_ok=True)
    os.makedirs('logs', exist_ok=True)

    image = kiui.read_image(path, mode='uint8')

    # generate mv
    image = image.astype(np.float32) / 255.0

    # rgba to rgb white bg
    if image.shape[-1] == 4:
        image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])

    mv_image = pipe('', image, guidance_scale=5.0, num_inference_steps=30, elevation=0)
    mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)  # [4, 256, 256, 3], float32

    # generate gaussians
    input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device)  # [4, 3, 256, 256]
    input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
    input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

    input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0)  # [1, 4, 9, H, W]

    with torch.inference_mode():
        ############## align azimuth #####################
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            # generate gaussians
            gaussians = model.forward_gaussians(input_image)

        best_azi = 0
        best_diff = 1e8
        for v, azi in enumerate(np.arange(-180, 180, 1)):
            cam_poses = torch.from_numpy(orbit_camera(0, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)

            cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

            # cameras needed by gaussian rasterizer
            cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
            cam_view_proj = cam_view @ proj_matrix  # [V, 4, 4]
            cam_pos = - cam_poses[:, :3, 3]  # [V, 3]

            # scale = min(azi / 360, 1)
            scale = 1


            result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)
            rendered_image = result['image']

            rendered_image = rendered_image.squeeze(1).permute(0,2,3,1).squeeze(0).contiguous().float().cpu().numpy()
            rendered_image = cv2.resize(rendered_image, (image.shape[0], image.shape[1]), interpolation=cv2.INTER_AREA)

            diff = np.mean((rendered_image - image) ** 2)

            if diff < best_diff:
                best_diff = diff
                best_azi = azi
        print("Best aligned azimuth: ", best_azi)

        mv_image = []
        for v, azi in enumerate([0, 90, 180, 270]):
            cam_poses = torch.from_numpy(orbit_camera(0, azi + best_azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)

            cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

            # cameras needed by gaussian rasterizer
            cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
            cam_view_proj = cam_view @ proj_matrix  # [V, 4, 4]
            cam_pos = - cam_poses[:, :3, 3]  # [V, 3]

            # scale = min(azi / 360, 1)
            scale = 1


            result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)
            rendered_image = result['image']
            rendered_image = rendered_image.squeeze(1)
            rendered_image = F.interpolate(rendered_image, (256, 256))
            rendered_image = rendered_image.permute(0,2,3,1).contiguous().float().cpu().numpy()
            mv_image.append(rendered_image)
        mv_image = np.concatenate(mv_image, axis=0)

        input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device)  # [4, 3, 256, 256]
        input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
        input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

        input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0)  # [1, 4, 9, H, W]

        ################################

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            # generate gaussians
            gaussians = model.forward_gaussians(input_image)

        # save gaussians
        model.gs.save_ply(gaussians, os.path.join('logs', name + '_model.ply'))

        # render 360 video
        images = []
        elevation = 0

        if opt.fancy_video:

            azimuth = np.arange(0, 720, 4, dtype=np.int32)
            for azi in tqdm.tqdm(azimuth):

                cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)

                cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

                # cameras needed by gaussian rasterizer
                cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
                cam_view_proj = cam_view @ proj_matrix  # [V, 4, 4]
                cam_pos = - cam_poses[:, :3, 3]  # [V, 3]

                scale = min(azi / 360, 1)

                image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image']
                images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
        else:
            azimuth = np.arange(0, 360, 2, dtype=np.int32)
            for azi in tqdm.tqdm(azimuth):

                cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)

                cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

                # cameras needed by gaussian rasterizer
                cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
                cam_view_proj = cam_view @ proj_matrix  # [V, 4, 4]
                cam_pos = - cam_poses[:, :3, 3]  # [V, 3]

                image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
                images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))

    images = np.concatenate(images, axis=0)
    imageio.mimwrite(os.path.join('vis_data', name + '_static.mp4'), images, fps=30)

main_4d_demo.py
ADDED
@@ -0,0 +1,616 @@
import os
import cv2
import time
import tqdm
import numpy as np

import torch
import torch.nn.functional as F

import rembg

from cam_utils import orbit_camera, OrbitCamera
from gs_renderer_4d import Renderer, MiniCam

from grid_put import mipmap_linear_grid_put_2d
import imageio

import copy
from omegaconf import OmegaConf

class GUI:
    def __init__(self, opt, guidance_zero123):
        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.gui = opt.gui  # enable gui
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)

        self.mode = "image"
        # self.seed = "random"
        self.seed = 888

        self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True  # update buffer_image

        # models
        self.device = torch.device("cuda")
        self.bg_remover = None

        self.guidance_sd = None
        self.guidance_zero123 = guidance_zero123
        self.guidance_svd = None


        self.enable_sd = False
        self.enable_zero123 = False
        self.enable_svd = False


        # renderer
        self.renderer = Renderer(self.opt, sh_degree=self.opt.sh_degree)
        self.gaussain_scale_factor = 1

        # input image
        self.input_img = None
        self.input_mask = None
        self.input_img_torch = None
        self.input_mask_torch = None
        self.overlay_input_img = False
        self.overlay_input_img_ratio = 0.5

        self.input_img_list = None
        self.input_mask_list = None
        self.input_img_torch_list = None
        self.input_mask_torch_list = None

        # input text
        self.prompt = ""
        self.negative_prompt = ""

        # training stuff
        self.training = False
        self.optimizer = None
        self.step = 0
        self.train_steps = 1  # steps per rendering loop

        # load input data from cmdline
        if self.opt.input is not None:  # True
            self.load_input(self.opt.input)  # load imgs, if has bg, then rm bg; or just load imgs

        # override prompt from cmdline
        if self.opt.prompt is not None:  # None
            self.prompt = self.opt.prompt

        # override if provide a checkpoint
        if self.opt.load is not None:  # not None
            self.renderer.initialize(self.opt.load)
            # self.renderer.gaussians.load_model(opt.outdir, opt.save_path)
        else:
            # initialize gaussians to a blob
            self.renderer.initialize(num_pts=self.opt.num_pts)

        self.seed_everything()

    def seed_everything(self):
        try:
            seed = int(self.seed)
        except:
            seed = np.random.randint(0, 1000000)

        print(f'Seed: {seed:d}')
        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True

        self.last_seed = seed

    def prepare_train(self):

        self.step = 0

        # setup training
        self.renderer.gaussians.training_setup(self.opt)

        # # do not do progressive sh-level
        self.renderer.gaussians.active_sh_degree = self.renderer.gaussians.max_sh_degree
        self.optimizer = self.renderer.gaussians.optimizer

        # default camera
        if self.opt.mvdream or self.opt.imagedream:
            # the second view is the front view for mvdream/imagedream.
            pose = orbit_camera(self.opt.elevation, 90, self.opt.radius)
        else:
            pose = orbit_camera(self.opt.elevation, 0, self.opt.radius)
        self.fixed_cam = MiniCam(
            pose,
            self.opt.ref_size,
            self.opt.ref_size,
            self.cam.fovy,
            self.cam.fovx,
            self.cam.near,
            self.cam.far,
        )

        self.enable_sd = self.opt.lambda_sd > 0
        self.enable_zero123 = self.opt.lambda_zero123 > 0
        self.enable_svd = self.opt.lambda_svd > 0 and self.input_img is not None

        # lazy load guidance model
        if self.guidance_sd is None and self.enable_sd:
            if self.opt.mvdream:
                print(f"[INFO] loading MVDream...")
                from guidance.mvdream_utils import MVDream
                self.guidance_sd = MVDream(self.device)
                print(f"[INFO] loaded MVDream!")
            elif self.opt.imagedream:
                print(f"[INFO] loading ImageDream...")
                from guidance.imagedream_utils import ImageDream
                self.guidance_sd = ImageDream(self.device)
                print(f"[INFO] loaded ImageDream!")
            else:
                print(f"[INFO] loading SD...")
                from guidance.sd_utils import StableDiffusion
                self.guidance_sd = StableDiffusion(self.device)
                print(f"[INFO] loaded SD!")

        if self.guidance_zero123 is None and self.enable_zero123:
            print(f"[INFO] loading zero123...")
            from guidance.zero123_utils import Zero123
            if self.opt.stable_zero123:
                self.guidance_zero123 = Zero123(self.device, model_key='ashawkey/stable-zero123-diffusers')
            else:
                self.guidance_zero123 = Zero123(self.device, model_key='ashawkey/zero123-xl-diffusers')
            print(f"[INFO] loaded zero123!")

        if self.guidance_svd is None and self.enable_svd:  # False
            print(f"[INFO] loading SVD...")
            from guidance.svd_utils import StableVideoDiffusion
            self.guidance_svd = StableVideoDiffusion(self.device)
            print(f"[INFO] loaded SVD!")

        # input image
        if self.input_img is not None:
            self.input_img_torch = torch.from_numpy(self.input_img).permute(2, 0, 1).unsqueeze(0).to(self.device)
            self.input_img_torch = F.interpolate(self.input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)

            self.input_mask_torch = torch.from_numpy(self.input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device)
            self.input_mask_torch = F.interpolate(self.input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)

        if self.input_img_list is not None:
            self.input_img_torch_list = [torch.from_numpy(input_img).permute(2, 0, 1).unsqueeze(0).to(self.device) for input_img in self.input_img_list]
            self.input_img_torch_list = [F.interpolate(input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False) for input_img_torch in self.input_img_torch_list]

            self.input_mask_torch_list = [torch.from_numpy(input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device) for input_mask in self.input_mask_list]
            self.input_mask_torch_list = [F.interpolate(input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False) for input_mask_torch in self.input_mask_torch_list]
        # prepare embeddings
        with torch.no_grad():

            if self.enable_sd:
                if self.opt.imagedream:
                    img_pos_list, img_neg_list, ip_pos_list, ip_neg_list, emb_pos_list, emb_neg_list = [], [], [], [], [], []
                    for _ in range(self.opt.n_views):
                        for input_img_torch in self.input_img_torch_list:
                            img_pos, img_neg, ip_pos, ip_neg, emb_pos, emb_neg = self.guidance_sd.get_image_text_embeds(input_img_torch, [self.prompt], [self.negative_prompt])
                            img_pos_list.append(img_pos)
                            img_neg_list.append(img_neg)
                            ip_pos_list.append(ip_pos)
                            ip_neg_list.append(ip_neg)
                            emb_pos_list.append(emb_pos)
                            emb_neg_list.append(emb_neg)
                    self.guidance_sd.image_embeddings['pos'] = torch.cat(img_pos_list, 0)
                    self.guidance_sd.image_embeddings['neg'] = torch.cat(img_pos_list, 0)
                    self.guidance_sd.image_embeddings['ip_img'] = torch.cat(ip_pos_list, 0)
                    self.guidance_sd.image_embeddings['neg_ip_img'] = torch.cat(ip_neg_list, 0)
                    self.guidance_sd.embeddings['pos'] = torch.cat(emb_pos_list, 0)
                    self.guidance_sd.embeddings['neg'] = torch.cat(emb_neg_list, 0)
                else:
                    self.guidance_sd.get_text_embeds([self.prompt], [self.negative_prompt])

            if self.enable_zero123:
                c_list, v_list = [], []
                for _ in range(self.opt.n_views):
                    for input_img_torch in self.input_img_torch_list:
                        c, v = self.guidance_zero123.get_img_embeds(input_img_torch)
                        c_list.append(c)
                        v_list.append(v)
                self.guidance_zero123.embeddings = [torch.cat(c_list, 0), torch.cat(v_list, 0)]

            if self.enable_svd:
                self.guidance_svd.get_img_embeds(self.input_img)

    def train_step(self):
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        for _ in range(self.train_steps):  # 1

            self.step += 1  # self.step starts from 0
            step_ratio = min(1, self.step / self.opt.iters)  # 1, step / 500

            # update lr
            self.renderer.gaussians.update_learning_rate(self.step)

            loss = 0

            self.renderer.prepare_render()

            ### known view
            if not self.opt.imagedream:
                for b_idx in range(self.opt.batch_size):
                    cur_cam = copy.deepcopy(self.fixed_cam)
                    cur_cam.time = b_idx
                    out = self.renderer.render(cur_cam)

                    # rgb loss
                    image = out["image"].unsqueeze(0)  # [1, 3, H, W] in [0, 1]
                    loss = loss + 10000 * step_ratio * F.mse_loss(image, self.input_img_torch_list[b_idx]) / self.opt.batch_size

                    # mask loss
                    mask = out["alpha"].unsqueeze(0)  # [1, 1, H, W] in [0, 1]
                    loss = loss + 1000 * step_ratio * F.mse_loss(mask, self.input_mask_torch_list[b_idx]) / self.opt.batch_size

            ### novel view (manual batch)
            render_resolution = 128 if step_ratio < 0.3 else (256 if step_ratio < 0.6 else 512)
            # render_resolution = 512
            images = []
            poses = []
            vers, hors, radii = [], [], []
            # avoid too large elevation (> 80 or < -80), and make sure it always cover [-30, 30]
            min_ver = max(min(self.opt.min_ver, self.opt.min_ver - self.opt.elevation), -80 - self.opt.elevation)
            max_ver = min(max(self.opt.max_ver, self.opt.max_ver - self.opt.elevation), 80 - self.opt.elevation)

            for _ in range(self.opt.n_views):
                for b_idx in range(self.opt.batch_size):

                    # render random view
                    ver = np.random.randint(min_ver, max_ver)
                    hor = np.random.randint(-180, 180)
                    radius = 0

                    vers.append(ver)
                    hors.append(hor)
                    radii.append(radius)

                    pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)
                    poses.append(pose)

                    cur_cam = MiniCam(pose, render_resolution, render_resolution, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, time=b_idx)

                    bg_color = torch.tensor([1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0], dtype=torch.float32, device="cuda")
                    out = self.renderer.render(cur_cam, bg_color=bg_color)

                    image = out["image"].unsqueeze(0)  # [1, 3, H, W] in [0, 1]
                    images.append(image)

                    # enable mvdream training
                    if self.opt.mvdream or self.opt.imagedream:  # False
                        for view_i in range(1, 4):
                            pose_i = orbit_camera(self.opt.elevation + ver, hor + 90 * view_i, self.opt.radius + radius)
                            poses.append(pose_i)

                            cur_cam_i = MiniCam(pose_i, render_resolution, render_resolution, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far)

                            # bg_color = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device="cuda")
                            out_i = self.renderer.render(cur_cam_i, bg_color=bg_color)

                            image = out_i["image"].unsqueeze(0)  # [1, 3, H, W] in [0, 1]
                            images.append(image)



            images = torch.cat(images, dim=0)
            poses = torch.from_numpy(np.stack(poses, axis=0)).to(self.device)

            # guidance loss
            if self.enable_sd:
                if self.opt.mvdream or self.opt.imagedream:
                    loss = loss + self.opt.lambda_sd * self.guidance_sd.train_step(images, poses, step_ratio)
                else:
                    loss = loss + self.opt.lambda_sd * self.guidance_sd.train_step(images, step_ratio)

            if self.enable_zero123:
                loss = loss + self.opt.lambda_zero123 * self.guidance_zero123.train_step(images, vers, hors, radii, step_ratio) / (self.opt.batch_size * self.opt.n_views)

            if self.enable_svd:
                loss = loss + self.opt.lambda_svd * self.guidance_svd.train_step(images, step_ratio)

            # optimize step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            # densify and prune
            if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:
                viewspace_point_tensor, visibility_filter, radii = out["viewspace_points"], out["visibility_filter"], out["radii"]
                self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
                self.renderer.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

                if self.step % self.opt.densification_interval == 0:
                    # size_threshold = 20 if self.step > self.opt.opacity_reset_interval else None
                    self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold, min_opacity=0.01, extent=0.5, max_screen_size=1)

                if self.step % self.opt.opacity_reset_interval == 0:
                    self.renderer.gaussians.reset_opacity()

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.need_update = True


    def load_input(self, file):
        if self.opt.data_mode == 'c4d':
            file_list = [os.path.join(file, f'{x * self.opt.downsample_rate}.png') for x in range(self.opt.batch_size)]
        elif self.opt.data_mode == 'svd':
            # file_list = [file.replace('.png', f'_frames/{x* self.opt.downsample_rate:03d}_rgba.png') for x in range(self.opt.batch_size)]
            # file_list = [x if os.path.exists(x) else (x.replace('_rgba.png', '.png')) for x in file_list]
            file_list = [file.replace('.png', f'_frames/{x* self.opt.downsample_rate:03d}.png') for x in range(self.opt.batch_size)]
        else:
            raise NotImplementedError
        self.input_img_list, self.input_mask_list = [], []
        for file in file_list:
            # load image
            print(f'[INFO] load image from {file}...')
            img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
            if img.shape[-1] == 3:
                if self.bg_remover is None:
                    self.bg_remover = rembg.new_session()
                img = rembg.remove(img, session=self.bg_remover)
                # cv2.imwrite(file.replace('.png', '_rgba.png'), img)
            img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
            img = img.astype(np.float32) / 255.0
            input_mask = img[..., 3:]
            # white bg
            input_img = img[..., :3] * input_mask + (1 - input_mask)
            # bgr to rgb
            input_img = input_img[..., ::-1].copy()
            self.input_img_list.append(input_img)
            self.input_mask_list.append(input_mask)

    @torch.no_grad()
    def save_model(self, mode='geo', texture_size=1024, interp=1):
        os.makedirs(self.opt.outdir, exist_ok=True)
        if mode == 'geo':
            path = f'logs/{opt.save_path}_mesh_{t:03d}.ply'
            mesh = self.renderer.gaussians.extract_mesh_t(path, self.opt.density_thresh, t=t)
            mesh.write_ply(path)

        elif mode == 'geo+tex':
            from mesh import Mesh, safe_normalize
            os.makedirs(os.path.join(self.opt.outdir, self.opt.save_path+'_meshes'), exist_ok=True)
            for t in range(self.opt.batch_size):
                path = os.path.join(self.opt.outdir, self.opt.save_path+'_meshes', f'{t:03d}.obj')
                mesh = self.renderer.gaussians.extract_mesh_t(path, self.opt.density_thresh, t=t)

                # perform texture extraction
                print(f"[INFO] unwrap uv...")
                h = w = texture_size
                mesh.auto_uv()
                mesh.auto_normal()

                albedo = torch.zeros((h, w, 3), device=self.device, dtype=torch.float32)
                cnt = torch.zeros((h, w, 1), device=self.device, dtype=torch.float32)

                vers = [0] * 8 + [-45] * 8 + [45] * 8 + [-89.9, 89.9]
                hors = [0, 45, -45, 90, -90, 135, -135, 180] * 3 + [0, 0]

                render_resolution = 512

                import nvdiffrast.torch as dr

                if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
                    glctx = dr.RasterizeGLContext()
                else:
                    glctx = dr.RasterizeCudaContext()

                for ver, hor in zip(vers, hors):
                    # render image
                    pose = orbit_camera(ver, hor, self.cam.radius)

                    cur_cam = MiniCam(
                        pose,
                        render_resolution,
                        render_resolution,
                        self.cam.fovy,
                        self.cam.fovx,
                        self.cam.near,
                        self.cam.far,
                        time=t
                    )

                    cur_out = self.renderer.render(cur_cam)

                    rgbs = cur_out["image"].unsqueeze(0)  # [1, 3, H, W] in [0, 1]

                    # get coordinate in texture image
                    pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)
                    proj = torch.from_numpy(self.cam.perspective.astype(np.float32)).to(self.device)

                    v_cam = torch.matmul(F.pad(mesh.v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
                    v_clip = v_cam @ proj.T
                    rast, rast_db = dr.rasterize(glctx, v_clip, mesh.f, (render_resolution, render_resolution))

                    depth, _ = dr.interpolate(-v_cam[..., [2]], rast, mesh.f)  # [1, H, W, 1]
                    depth = depth.squeeze(0)  # [H, W, 1]

                    alpha = (rast[0, ..., 3:] > 0).float()

                    uvs, _ = dr.interpolate(mesh.vt.unsqueeze(0), rast, mesh.ft)  # [1, 512, 512, 2] in [0, 1]

                    # use normal to produce a back-project mask
                    normal, _ = dr.interpolate(mesh.vn.unsqueeze(0).contiguous(), rast, mesh.fn)
                    normal = safe_normalize(normal[0])

                    # rotated normal (where [0, 0, 1] always faces camera)
                    rot_normal = normal @ pose[:3, :3]
                    viewcos = rot_normal[..., [2]]

                    mask = (alpha > 0) & (viewcos > 0.5)  # [H, W, 1]
                    mask = mask.view(-1)

                    uvs = uvs.view(-1, 2).clamp(0, 1)[mask]
                    rgbs = rgbs.view(3, -1).permute(1, 0)[mask].contiguous()

                    # update texture image
                    cur_albedo, cur_cnt = mipmap_linear_grid_put_2d(
                        h, w,
                        uvs[..., [1, 0]] * 2 - 1,
                        rgbs,
                        min_resolution=256,
                        return_count=True,
                    )

                    mask = cnt.squeeze(-1) < 0.1
                    albedo[mask] += cur_albedo[mask]
                    cnt[mask] += cur_cnt[mask]

                mask = cnt.squeeze(-1) > 0
                albedo[mask] = albedo[mask] / cnt[mask].repeat(1, 3)

                mask = mask.view(h, w)

                albedo = albedo.detach().cpu().numpy()
                mask = mask.detach().cpu().numpy()

                # dilate texture
                from sklearn.neighbors import NearestNeighbors
                from scipy.ndimage import binary_dilation, binary_erosion

                inpaint_region = binary_dilation(mask, iterations=32)
                inpaint_region[mask] = 0

                search_region = mask.copy()
                not_search_region = binary_erosion(search_region, iterations=3)
                search_region[not_search_region] = 0

                search_coords = np.stack(np.nonzero(search_region), axis=-1)
                inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)

                knn = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(
                    search_coords
                )
                _, indices = knn.kneighbors(inpaint_coords)

                albedo[tuple(inpaint_coords.T)] = albedo[tuple(search_coords[indices[:, 0]].T)]

                mesh.albedo = torch.from_numpy(albedo).to(self.device)
                mesh.write(path)


        elif mode == 'frames':
            os.makedirs(os.path.join(self.opt.outdir, self.opt.save_path+'_frames'), exist_ok=True)
            for t in range(self.opt.batch_size * interp):
                tt = t / interp
                path = os.path.join(self.opt.outdir, self.opt.save_path+'_frames', f'{t:03d}.ply')
                self.renderer.gaussians.save_frame_ply(path, tt)
        else:
            path = os.path.join(self.opt.outdir, self.opt.save_path + '_4d_model.ply')
            self.renderer.gaussians.save_ply(path)
            self.renderer.gaussians.save_deformation(self.opt.outdir, self.opt.save_path)

        print(f"[INFO] save model to {path}.")

    # no gui mode
    def train(self, iters=500, ui=False):
        if self.gui:
            from visualizer.visergui import ViserViewer
            self.viser_gui = ViserViewer(device="cuda", viewer_port=8080)
        if iters > 0:
            self.prepare_train()
            if self.gui:
                self.viser_gui.set_renderer(self.renderer, self.fixed_cam)

            for i in tqdm.trange(iters):
                self.train_step()
                if self.gui:
                    self.viser_gui.update()
        if self.opt.mesh_format == 'frames':
            self.save_model(mode='frames', interp=4)
        elif self.opt.mesh_format == 'obj':
            self.save_model(mode='geo+tex')

        if self.opt.save_model:
            self.save_model(mode='model')

        # render eval
        image_list = []
        nframes = self.opt.batch_size * 7 + 15 * 7
        hor = 180
        delta_hor = 45 / 15
        delta_time = 1
        for i in range(8):
            time = 0
            for j in range(self.opt.batch_size + 15):
                pose = orbit_camera(self.opt.elevation, hor-180, self.opt.radius)
                cur_cam = MiniCam(
                    pose,
                    512,
                    512,
                    self.cam.fovy,
                    self.cam.fovx,
                    self.cam.near,
                    self.cam.far,
                    time=time
                )
                with torch.no_grad():
                    outputs = self.renderer.render(cur_cam)

                out = outputs["image"].cpu().detach().numpy().astype(np.float32)
                out = np.transpose(out, (1, 2, 0))
                out = np.uint8(out*255)
                image_list.append(out)

                time = (time + delta_time) % self.opt.batch_size
                if j >= self.opt.batch_size:
                    hor = (hor+delta_hor) % 360


        imageio.mimwrite(f'vis_data/{opt.save_path}.mp4', image_list, fps=7)

        if self.gui:
            while True:
                self.viser_gui.update()

def process(config, input_path, guidance):
    # override default config from cli
    opt = OmegaConf.load(config)
    opt.input = input_path
    opt.save_path = os.path.splitext(os.path.basename(opt.input))[0] if opt.save_path == '' else opt.save_path


    # auto find mesh from stage 1
    opt.load = os.path.join(opt.outdir, opt.save_path + '_model.ply')

    gui = GUI(opt, guidance)

    gui.train(opt.iters)


if __name__ == "__main__":
    import argparse
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to the yaml config file")
    args, extras = parser.parse_known_args()

    # override default config from cli
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
    opt.save_path = os.path.splitext(os.path.basename(opt.input))[0] if opt.save_path == '' else opt.save_path


    # auto find mesh from stage 1
    opt.load = os.path.join(opt.outdir, opt.save_path + '_model.ply')

    gui = GUI(opt)

    gui.train(opt.iters)


# python main_4d.py --config configs/4d_low.yaml input=data/CONSISTENT4D_DATA/in-the-wild/blooming_rose