Spaces:
Runtime error
Runtime error
| import os | |
| import tyro | |
| import glob | |
| import imageio | |
| import numpy as np | |
| import tqdm | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms.functional as TF | |
| from safetensors.torch import load_file | |
| import rembg | |
| import kiui | |
| from kiui.op import recenter | |
| from kiui.cam import orbit_camera | |
| from core.options import AllConfigs, Options | |
| from core.models import LGM | |
| from mvdream.pipeline_mvdream import MVDreamPipeline | |
| import cv2 | |
# ImageNet normalisation statistics, applied to every multi-view image before
# it is fed to the LGM encoder (see TF.normalize calls in process()).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

opt = tyro.cli(AllConfigs)

# model
model = LGM(opt)

# resume pretrained checkpoint
if opt.resume is not None:
    if opt.resume.endswith('safetensors'):
        ckpt = load_file(opt.resume, device='cpu')
    else:
        # NOTE(review): torch.load unpickles arbitrary objects; only point
        # `resume` at checkpoints from a trusted source.
        ckpt = torch.load(opt.resume, map_location='cpu')
    # strict=False tolerates key mismatches; report them rather than silently
    # running with partially initialised weights.
    load_result = model.load_state_dict(ckpt, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f'[INFO] checkpoint key mismatch (strict=False): '
              f'{len(load_result.missing_keys)} missing, '
              f'{len(load_result.unexpected_keys)} unexpected')
    print(f'[INFO] Loaded checkpoint from {opt.resume}')
else:
    print('[WARN] model randomly initialized, are you sure?')

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): half precision here (and the float16 diffusion pipeline and
# CUDA autocast used later) is only practical on CUDA; the CPU fallback is
# unlikely to work end-to-end.
model = model.half().to(device)
model.eval()

rays_embeddings = model.prepare_default_rays(device)

# Perspective projection matrix. It is stored transposed (translation term in
# [3, 2], w-copy in [2, 3]) so it can right-multiply the transposed view
# matrices built as `cam_view @ proj_matrix`.
tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
proj_matrix[0, 0] = 1 / tan_half_fov
proj_matrix[1, 1] = 1 / tan_half_fov
proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
proj_matrix[2, 3] = 1

# load image dream (multi-view diffusion conditioned on a single image)
pipe = MVDreamPipeline.from_pretrained(
    "ashawkey/imagedream-ipmv-diffusers",  # remote weights
    torch_dtype=torch.float16,
    trust_remote_code=True,
    # local_files_only=True,
)
pipe = pipe.to(device)

# load rembg (background-removal session reused across all inputs)
bg_remover = rembg.new_session()
| # process function | |
def _orbit_cameras(opt: Options, elevation, azimuth):
    """Build the rasterizer camera tensors for one orbit camera.

    The camera is placed with ``kiui.cam.orbit_camera`` at radius
    ``opt.cam_radius``. Returns ``(cam_view, cam_view_proj, cam_pos)``, each
    with a leading view dimension of 1.
    """
    cam_poses = torch.from_numpy(
        orbit_camera(elevation, azimuth, radius=opt.cam_radius, opengl=True)
    ).unsqueeze(0).to(device)
    cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction
    # cameras needed by gaussian rasterizer
    cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
    cam_view_proj = cam_view @ proj_matrix  # [V, 4, 4]
    cam_pos = - cam_poses[:, :3, 3]  # [V, 3]
    return cam_view, cam_view_proj, cam_pos


def _render(opt: Options, gaussians, elevation, azimuth, scale=1):
    """Render ``gaussians`` from one orbit camera and return the 'image' tensor."""
    cam_view, cam_view_proj, cam_pos = _orbit_cameras(opt, elevation, azimuth)
    result = model.gs.render(
        gaussians,
        cam_view.unsqueeze(0),
        cam_view_proj.unsqueeze(0),
        cam_pos.unsqueeze(0),
        scale_modifier=scale,
    )
    return result['image']


def _to_model_input(opt: Options, mv_image):
    """Turn 4 RGB views ``[4, H, W, 3]`` (numpy) into the LGM input ``[1, 4, 9, S, S]``.

    Views are resized to ``opt.input_size``, ImageNet-normalised and
    concatenated with the ray embeddings along the channel axis.
    """
    views = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device)  # [4, 3, H, W]
    views = F.interpolate(views, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
    views = TF.normalize(views, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
    return torch.cat([views, rays_embeddings], dim=1).unsqueeze(0)  # [1, 4, 9, H, W]


def _frame_to_uint8(frame):
    """Convert a rendered image tensor to a uint8 ``[N, H, W, 3]`` numpy array."""
    return (frame.squeeze(1).permute(0, 2, 3, 1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)


def _align_azimuth(opt: Options, gaussians, reference):
    """Return the azimuth in [-180, 179] whose render best matches ``reference``.

    Each candidate render at elevation 0 is resized to ``reference``'s size and
    scored by mean squared error; the first azimuth with the lowest error wins.
    """
    height, width = reference.shape[:2]
    best_azi = 0
    best_diff = 1e8
    for azi in np.arange(-180, 180, 1):
        # assumes the renderer returns [B, V, 3, H, W] with B = V = 1 -- TODO confirm
        rendered = _render(opt, gaussians, 0, azi)
        rendered = rendered.squeeze(1).permute(0, 2, 3, 1).squeeze(0).contiguous().float().cpu().numpy()
        # cv2.resize takes dsize as (width, height), not (height, width).
        rendered = cv2.resize(rendered, (width, height), interpolation=cv2.INTER_AREA)
        diff = np.mean((rendered - reference) ** 2)
        if diff < best_diff:
            best_diff = diff
            best_azi = azi
    return best_azi


def process(opt: Options, path):
    """Reconstruct 3D Gaussians from the image at ``path`` and render a turntable.

    Pipeline: background removal -> recentering -> ImageDream multi-view
    generation -> LGM reconstruction -> azimuth alignment against the input ->
    second LGM reconstruction from re-rendered aligned views. Writes
    ``logs/<name>_model.ply`` and ``vis_data/<name>_static.mp4``.
    """
    name = os.path.splitext(os.path.basename(path))[0]
    if 'CONSISTENT4D' in path:
        # Name by parent directory for this dataset (presumably because its
        # frame file names repeat across sequences).
        name = path.split('/')[-2]
    print(f'[INFO] Processing {path} --> {name}')
    os.makedirs('vis_data', exist_ok=True)
    os.makedirs('logs', exist_ok=True)

    input_image = kiui.read_image(path, mode='uint8')

    # bg removal
    carved_image = rembg.remove(input_image, session=bg_remover)  # [H, W, 4]
    mask = carved_image[..., -1] > 0

    # recenter
    image = recenter(carved_image, mask, border_ratio=0.2)

    # generate mv
    image = image.astype(np.float32) / 255.0
    # rgba to rgb white bg
    if image.shape[-1] == 4:
        image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])

    mv_image = pipe('', image, guidance_scale=5.0, num_inference_steps=30, elevation=0)
    mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)  # [4, 256, 256, 3], float32

    input_image = _to_model_input(opt, mv_image)

    with torch.no_grad():
        # Pass 1: reconstruct from the diffusion views, then find the azimuth
        # whose render best matches the recentred input image.
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            gaussians = model.forward_gaussians(input_image)
        best_azi = _align_azimuth(opt, gaussians, image)
        print("Best aligned azimuth: ", best_azi)

        # Pass 2: re-render four orthogonal views starting at the aligned
        # azimuth and reconstruct again from those renders.
        aligned_views = []
        for azi in (0, 90, 180, 270):
            rendered = _render(opt, gaussians, 0, azi + best_azi).squeeze(1)
            rendered = F.interpolate(rendered, (256, 256))  # default mode='nearest'
            aligned_views.append(rendered.permute(0, 2, 3, 1).contiguous().float().cpu().numpy())
        mv_image = np.concatenate(aligned_views, axis=0)

        input_image = _to_model_input(opt, mv_image)
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            gaussians = model.forward_gaussians(input_image)

        # save gaussians
        model.gs.save_ply(gaussians, os.path.join('logs', name + '_model.ply'))

        # render 360 video; the fancy variant does two turns while the
        # gaussian scale grows from 0 to 1 over the first turn.
        elevation = 0
        if opt.fancy_video:
            azimuths = np.arange(0, 720, 4, dtype=np.int32)
        else:
            azimuths = np.arange(0, 360, 2, dtype=np.int32)
        frames = []
        for azi in tqdm.tqdm(azimuths):
            scale = min(azi / 360, 1) if opt.fancy_video else 1
            frames.append(_frame_to_uint8(_render(opt, gaussians, elevation, azi, scale)))
        frames = np.concatenate(frames, axis=0)
        imageio.mimwrite(os.path.join('vis_data', name + '_static.mp4'), frames, fps=30)
# Collect the inputs. A directory is expanded to the files directly inside it,
# sorted so the processing order is deterministic; any other path is used as-is.
# Validation uses an explicit raise because `assert` is stripped under `python -O`.
if opt.test_path is None:
    raise ValueError('test_path must be set to an image file or a directory of images')
if os.path.isdir(opt.test_path):
    file_paths = sorted(
        p for p in glob.glob(os.path.join(opt.test_path, "*")) if os.path.isfile(p)
    )
else:
    file_paths = [opt.test_path]

for path in file_paths:
    process(opt, path)