#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import math
import os

# read .exr files for the RTMV dataset; this flag must be set before cv2 is
# imported so that OpenCV's OpenEXR codec is enabled
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

import cv2
import einops
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.checkpoint
import torchvision
from accelerate.utils import ProjectConfiguration, set_seed
from PIL import Image
from skimage.io import imsave
from torchvision import transforms
from tqdm.auto import tqdm
def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a Zero123 evaluation script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="lambdalabs/sd-image-variations-diffusers",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should"
            " be float32 precision."
        ),
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible inference.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=256,
        help=(
            "The resolution for input images; all images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--T_in", type=int, default=1, help="Number of input views.")
    parser.add_argument("--T_out", type=int, default=1, help="Number of output views.")
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=3.0,
        help="Unconditional guidance scale; classifier-free guidance is applied when guidance_scale > 1.0.",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default=".",
        help="The input data dir. Should contain the .png files (or other data files) for the task.",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        default="GSO25",
        help="The input data type. Chosen from GSO25, GSO3D, GSO100, RTMV, NeRF, Franka, MVDream, Text2Img.",
    )
    parser.add_argument(
        "--cape_type",
        type=str,
        default="6DoF",
        help="The camera pose encoding (CaPE) type. Chosen from 4DoF, 6DoF.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="logs_eval",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system"
            " or the flag passed with the `accelerate.launch` command. Use this argument to override the"
            " accelerate config."
        ),
    )
    # use an explicit boolean flag so a stray string argument cannot silently
    # become a truthy value
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention",
        action="store_true",
        default=True,
        help="Whether or not to use xformers.",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    if args.resolution % 8 != 0:
        raise ValueError("`--resolution` must be divisible by 8 for consistently sized encoded images.")

    return args
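# Example invocation (illustrative; the script name is a placeholder and
# --data_dir must match one of the DATA_TYPE layouts handled in main() below):
#   python eval.py --pretrained_model_name_or_path <model_or_hub_id> \
#       --data_dir ./GSO --data_type GSO25 --cape_type 6DoF --T_in 3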
# create camera positions/angles on an archimedean spiral with num_steps steps
def get_archimedean_spiral(sphere_radius, num_steps=250):
    # x-z plane, around upper y
    '''
    https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral". c = a / pi
    '''
    a = 40
    r = sphere_radius
    translations = []
    angles = []
    # i = a / 2
    i = 0.01
    while i < a:
        theta = i / a * math.pi
        x = r * math.sin(theta) * math.cos(-i)
        z = r * math.sin(-theta + math.pi) * math.sin(-i)
        y = r * -math.cos(theta)
        # translations.append((x, y, z))  # origin
        translations.append((x, z, -y))
        angles.append([np.rad2deg(-i), np.rad2deg(theta)])
        # i += a / (2 * num_steps)
        i += a / (1 * num_steps)
    return np.array(translations), np.stack(angles)
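# Sanity check (illustrative, kept commented out): every returned translation
# lies exactly on the sphere of the requested radius.
#   xyzs, _ = get_archimedean_spiral(2.2, num_steps=100)
#   assert np.allclose(np.linalg.norm(xyzs, axis=-1), 2.2)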
# num_steps views around a circle at the given elevation (degrees)
def get_circle_traj(sphere_radius, elevation=0, num_steps=36):
    translations = []
    angles = []
    elevation = np.deg2rad(elevation)
    for i in range(num_steps):
        theta = i / num_steps * 2 * math.pi
        x = sphere_radius * math.sin(theta) * math.cos(elevation)
        z = sphere_radius * math.sin(-theta + math.pi) * math.sin(-elevation)
        y = sphere_radius * -math.cos(theta)
        translations.append((x, z, -y))
        angles.append([np.rad2deg(-elevation), np.rad2deg(theta)])
    return np.array(translations), np.stack(angles)
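# Example (illustrative): 36 cameras on a ring at 30 degrees elevation.
#   ring_xyzs, ring_angles = get_circle_traj(2.2, elevation=30, num_steps=36)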
def look_at(origin, target, up):
    # build a 4x4 camera-to-world matrix with the camera centred at `target`;
    # note that the call sites below pass (origin, camera_position, up)
    forward = (target - origin)
    forward = forward / np.linalg.norm(forward)
    right = np.cross(up, forward)
    right = right / np.linalg.norm(right)
    new_up = np.cross(forward, right)
    rotation_matrix = np.column_stack((right, new_up, -forward, target))
    matrix = np.vstack((rotation_matrix, [0, 0, 0, 1]))  # np.row_stack was removed in NumPy 2.0
    return matrix
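# Example (illustrative): a camera on the +x axis at radius 2.2, z-up.
#   cam2world = look_at(np.zeros(3), np.array([2.2, 0., 0.]), np.array([0., 0., 1.]))
#   world2cam = np.linalg.inv(cam2world)  # the pose branches below invert it the same way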
# from carvekit.api.high import HiInterface
# def create_carvekit_interface():
#     # Check doc strings for more information
#     interface = HiInterface(object_type="object",  # Can be "object" or "hairs-like".
#                             batch_size_seg=5,
#                             batch_size_matting=1,
#                             device='cuda' if torch.cuda.is_available() else 'cpu',
#                             seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
#                             matting_mask_size=2048,
#                             trimap_prob_threshold=231,
#                             trimap_dilation=30,
#                             trimap_erosion_iters=5,
#                             fp16=False)
#
#     return interface

import rembg

def create_rembg_interface():
    # create the session once and reuse it so the segmentation model is only loaded once
    rembg_session = rembg.new_session()
    return rembg_session
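# Usage (illustrative): rembg takes a PIL image and returns an RGBA PIL image
# whose alpha channel is the foreground matte, as used in the dust3r branch below.
#   cutout = rembg.remove(pil_img, session=create_rembg_interface())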
def main(args):
    if args.seed is not None:
        set_seed(args.seed)

    CaPE_TYPE = args.cape_type
    if CaPE_TYPE in ("6DoF", "4DoF"):
        import sys
        sys.path.insert(0, f"./{CaPE_TYPE}/")
        # use the customized diffusers modules shipped with each CaPE variant
        from diffusers import DDIMScheduler
        from dataset import get_pose
        from CN_encoder import CN_encoder
        from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
    else:
        raise ValueError("CaPE_TYPE must be chosen from 4DoF, 6DoF")
    DATA_DIR = args.data_dir
    DATA_TYPE = args.data_type
    if DATA_TYPE == "GSO25":
        T_in_DATA_TYPE = "render_mvs_25"   # same condition for GSO
        T_out_DATA_TYPE = "render_mvs_25"  # for 2D metrics
        T_out = 25
    elif DATA_TYPE == "GSO25_6dof":
        T_in_DATA_TYPE = "render_6dof_25"   # same condition for GSO
        T_out_DATA_TYPE = "render_6dof_25"  # for 2D metrics
        T_out = 25
    elif DATA_TYPE == "GSO3D":
        T_in_DATA_TYPE = "render_mvs_25"           # same condition for GSO
        T_out_DATA_TYPE = "render_sync_36_single"  # for 3D metrics
        T_out = 36
    elif DATA_TYPE == "GSO100":
        T_in_DATA_TYPE = "render_mvs_25"       # same condition for GSO
        T_out_DATA_TYPE = "render_spiral_100"  # for 360 gif
        T_out = 100
    elif DATA_TYPE == "NeRF":
        T_out = 200
    elif DATA_TYPE == "RTMV":
        T_out = 20
    elif DATA_TYPE == "Franka":
        T_out = 100  # do a 360 gif
    elif DATA_TYPE == "MVDream":
        T_out = 100  # do a 360 gif
    elif DATA_TYPE == "Text2Img":
        T_out = 100  # do a 360 gif
    elif DATA_TYPE == "dust3r":
        # carvekit = create_carvekit_interface()
        rembg_session = create_rembg_interface()
        T_out = 50  # do a 360 gif
        # the number of input views equals the number of .png files in the folder
        obj_names = [f for f in os.listdir(DATA_DIR + "/user_object") if f.endswith('.png')]
        args.T_in = len(obj_names)
    else:
        raise NotImplementedError

    T_in = args.T_in
    OUTPUT_DIR = f"logs_{CaPE_TYPE}/{DATA_TYPE}/N{T_in}M{T_out}"
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # get all object folders in DATA_DIR
    if DATA_TYPE == "Text2Img":
        # get all rgba .png files in DATA_DIR
        obj_names = [f for f in os.listdir(DATA_DIR) if f.endswith('rgba.png')]
    else:
        obj_names = [f for f in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, f))]

    weight_dtype = torch.float16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    h, w = args.resolution, args.resolution
    bg_color = [1., 1., 1., 1.]
    radius = 2.2  # 1.5  # 1.8  # Objaverse training radius [1.5, 2.2]
    # radius_4dof = np.pi * (np.log(radius) - np.log(1.5)) / (np.log(2.2) - np.log(1.5))

    # Init Dataset
    image_transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize((args.resolution, args.resolution)),  # 256, 256
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
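    # Note: ToTensor maps pixels to [0, 1] and Normalize([0.5], [0.5]) then maps
    # them to [-1, 1], the input range the diffusion pipeline expects.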
    # Init pipeline
    scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler",
                                              revision=args.revision)
    image_encoder = CN_encoder.from_pretrained(args.pretrained_model_name_or_path, subfolder="image_encoder", revision=args.revision)
    pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        revision=args.revision,
        scheduler=scheduler,
        image_encoder=None,
        safety_checker=None,
        feature_extractor=None,
        torch_dtype=weight_dtype,
    )
    # attach the custom CN_encoder after construction; it is deliberately not
    # loaded by from_pretrained above
    pipeline.image_encoder = image_encoder
    pipeline = pipeline.to(device)
    pipeline.set_progress_bar_config(disable=False)
    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()
    # enable vae slicing for lower peak memory
    pipeline.enable_vae_slicing()

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(args.seed)
    for obj_name in tqdm(obj_names):
        print(f"Processing {obj_name}")
        if DATA_TYPE == "NeRF":
            # results are written under OUTPUT_DIR, so the resume-skip check looks there
            if os.path.exists(os.path.join(OUTPUT_DIR, obj_name, "output.gif")):
                continue
            # load train info
            with open(os.path.join(DATA_DIR, obj_name, "transforms_train.json"), "r") as f:
                train_info = json.load(f)["frames"]
            # load test info
            with open(os.path.join(DATA_DIR, obj_name, "transforms_test.json"), "r") as f:
                test_info = json.load(f)["frames"]

            # find the radius range [min_t, max_t] of the object; it is later scaled to the training radius [1.5, 2.2]
            max_t = 0
            min_t = 100
            for i in range(len(train_info)):
                pose = np.array(train_info[i]["transform_matrix"]).reshape(4, 4)
                translation = pose[:3, -1]
                radii = np.linalg.norm(translation)
                if max_t < radii:
                    max_t = radii
                if min_t > radii:
                    min_t = radii

            info_dir = os.path.join("metrics/NeRF_idx", obj_name)
            assert os.path.exists(info_dir)  # use fixed train index
            train_index = np.load(os.path.join(info_dir, f"train_N{T_in}M20_random.npy"))
            test_index = np.arange(len(test_info))  # use all test views
| elif DATA_TYPE == "Franka": | |
| angles_in = np.load(os.path.join(DATA_DIR, obj_name, "angles.npy")) # azimuth, elevation in radians | |
| assert T_in <= len(angles_in) | |
| total_index = np.arange(0, len(angles_in)) # num of input views | |
| # random shuffle total_index | |
| np.random.shuffle(total_index) | |
| train_index = total_index[:T_in] | |
| xyzs, angles_out = get_archimedean_spiral(radius, T_out) | |
| origin = np.array([0, 0, 0]) | |
| up = np.array([0, 0, 1]) | |
| test_index = np.arange(len(angles_out)) # use all 100 test views | |
| elif DATA_TYPE == "MVDream": # 4 input views front right back left | |
| angles_in = [] | |
| for polar in [90]: # 1 | |
| for azimu in np.arange(0, 360, 90): # 4 | |
| angles_in.append(np.array([azimu, polar])) | |
| assert T_in == len(angles_in) | |
| xyzs, angles_out = get_archimedean_spiral(radius, T_out) | |
| origin = np.array([0, 0, 0]) | |
| up = np.array([0, 0, 1]) | |
| train_index = np.arange(T_in) | |
| test_index = np.arange(T_out) | |
| elif DATA_TYPE == "Text2Img": # 1 input view | |
| angles_in = [] | |
| angles_in.append(np.array([0, 90])) | |
| assert T_in == len(angles_in) | |
| xyzs, angles_out = get_archimedean_spiral(radius, T_out) | |
| origin = np.array([0, 0, 0]) | |
| up = np.array([0, 0, 1]) | |
| train_index = np.arange(T_in) | |
| test_index = np.arange(T_out) | |
| elif DATA_TYPE == "dust3r": | |
| # TODO full archimedean spiral traj | |
| # xyzs, angles_out = get_archimedean_spiral(radius, T_out) | |
| # TODO only top circle traj | |
| xyzs, angles_out = get_archimedean_spiral(1.5, 100) | |
| xyzs = xyzs[:T_out] | |
| angles_out = angles_out[:T_out] | |
| # # TODO circle traj | |
| # xyzs, angles_out = get_circle_traj(radius, elevation=30, num_steps=T_out) | |
| origin = np.array([0, 0, 0]) | |
| up = np.array([0, 0, 1]) | |
| train_index = np.arange(T_in) | |
| test_index = np.arange(T_out) | |
| # get the max_t | |
| radii = np.load(os.path.join(DATA_DIR, obj_name, "radii.npy")) | |
| max_t = np.max(radii) | |
| min_t = np.min(radii) | |
| else: | |
| train_index = np.arange(T_in) | |
| test_index = np.arange(T_out) | |
        # prepare input img + pose, output pose
        input_image = []
        pose_in = []
        pose_out = []
        gt_image = []
        for T_in_index in train_index:
            if DATA_TYPE == "RTMV":
                img_path = os.path.join(DATA_DIR, obj_name, '%05d.exr' % T_in_index)
                input_im = cv2.imread(img_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
                img = cv2.cvtColor(input_im, cv2.COLOR_BGR2RGB, input_im)
                img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB")
                input_image.append(image_transforms(img))
                # load input pose
                pose_path = os.path.join(DATA_DIR, obj_name, '%05d.json' % T_in_index)
                with open(pose_path, "r") as f:
                    pose_dict = json.load(f)
                input_RT = np.array(pose_dict["camera_data"]["cam2world"]).T
                input_RT = np.linalg.inv(input_RT)[:3]
                pose_in.append(get_pose(np.concatenate([input_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
            else:
                if DATA_TYPE == "NeRF":
                    img_path = os.path.join(DATA_DIR, obj_name, train_info[T_in_index]["file_path"] + ".png")
                    pose = np.array(train_info[T_in_index]["transform_matrix"])
                    if CaPE_TYPE == "6DoF":
                        # blender to opencv
                        pose[1:3, :] *= -1
                        pose = np.linalg.inv(pose)
                        # scale radius to [1.5, 2.2]
                        pose[:3, 3] *= 1. / max_t * radius
                    elif CaPE_TYPE == "4DoF":
                        pose = np.linalg.inv(pose)
                    pose_in.append(torch.from_numpy(get_pose(pose)))
                elif DATA_TYPE == "Franka":
                    img_path = os.path.join(DATA_DIR, obj_name, "images_rgba", f"frame{T_in_index:06d}.png")
                    azimuth, elevation = np.rad2deg(angles_in[T_in_index])
                    print("input angles index", T_in_index, "azimuth", azimuth, "elevation", elevation)
                    if CaPE_TYPE == "4DoF":
                        # torch.from_numpy only accepts ndarrays, so build the tensor directly
                        pose_in.append(torch.tensor([np.deg2rad(90. - elevation), np.deg2rad(azimuth - 180), 0., 0.]))
                    elif CaPE_TYPE == "6DoF":
                        neg_i = np.deg2rad(azimuth - 180)
                        neg_theta = np.deg2rad(90. - elevation)
                        xyz = np.array([np.sin(neg_theta) * np.cos(neg_i),
                                        np.sin(-neg_theta + np.pi) * np.sin(neg_i),
                                        np.cos(neg_theta)]) * radius
                        pose = look_at(origin, xyz, up)
                        pose = np.linalg.inv(pose)
                        pose[2, :] *= -1
                        pose_in.append(torch.from_numpy(get_pose(pose)))
| elif DATA_TYPE == "MVDream" or DATA_TYPE == "Text2Img": | |
| if DATA_TYPE == "MVDream": | |
| img_path = os.path.join(DATA_DIR, obj_name, f"{T_in_index}_rgba.png") | |
| elif DATA_TYPE == "Text2Img": | |
| img_path = os.path.join(DATA_DIR, obj_name) | |
| azimuth, polar = angles_in[T_in_index] | |
| if CaPE_TYPE == "4DoF": | |
| pose_in.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.])) | |
| elif CaPE_TYPE == "6DoF": | |
| neg_theta = np.deg2rad(polar) | |
| neg_i = np.deg2rad(azimuth) | |
| xyz = np.array([np.sin(neg_theta) * np.cos(neg_i), | |
| np.sin(-neg_theta + np.pi) * np.sin(neg_i), | |
| np.cos(neg_theta)]) * radius | |
| pose = look_at(origin, xyz, up) | |
| pose = np.linalg.inv(pose) | |
| pose[2, :] *= -1 | |
| pose_in.append(torch.from_numpy(get_pose(pose))) | |
| elif DATA_TYPE == "dust3r": # TODO get the object coordinate, now one of the camera is the center | |
| img_path = os.path.join(DATA_DIR, obj_name, "%03d.png" % T_in_index) | |
| pose = get_pose(np.linalg.inv(np.load(os.path.join(DATA_DIR, obj_name, "%03d.npy" % T_in_index)))) | |
| pose[1:3, :] *= -1 | |
| # scale radius to [1.5, 2.2] | |
| pose[:3, 3] *= 1. / max_t * radius | |
| pose_in.append(torch.from_numpy(pose)) | |
| else: # GSO | |
| img_path = os.path.join(DATA_DIR, obj_name, T_in_DATA_TYPE, "model/%03d.png" % T_in_index) | |
| pose_path = os.path.join(DATA_DIR, obj_name, T_in_DATA_TYPE, "model/%03d.npy" % T_in_index) | |
| if T_in_DATA_TYPE == "render_mvs_25" or T_in_DATA_TYPE == "render_6dof_25": # blender coordinate | |
| pose_in.append(get_pose(np.concatenate([np.load(pose_path)[:3, :], np.array([[0, 0, 0, 1]])], axis=0))) | |
| else: # opencv coordinate | |
| pose = get_pose(np.concatenate([np.load(pose_path)[:3, :], np.array([[0, 0, 0, 1]])], axis=0)) | |
| pose[1:3, :] *= -1 # pose out 36 is in opencv coordinate, pose in 25 is in blender coordinate | |
| pose_in.append(torch.from_numpy(pose)) | |
| # pose_in.append(get_pose(np.concatenate([np.load(pose_path)[:3, :], np.array([[0, 0, 0, 1]])], axis=0))) | |
                # load image
                img = plt.imread(img_path)
                if (img.shape[-1] == 3 or (img[:, :, -1] == 1).all()) and DATA_TYPE == "dust3r":
                    img_pil = Image.fromarray(np.uint8(img * 255.)).convert("RGB")  # to PIL image
                    ## use carvekit
                    # image_without_background = carvekit([img_pil])[0]
                    # image_without_background = np.array(image_without_background)
                    # est_seg = image_without_background > 127
                    # foreground = est_seg[:, :, -1].astype(np.bool_)
                    # img = np.concatenate([img[:,:,:3], foreground[:, :, np.newaxis]], axis=-1)
                    # use rembg: take the foreground mask from the returned alpha matte
                    image = rembg.remove(img_pil, session=rembg_session)
                    foreground = np.array(image)[:, :, -1] > 127
                    img = np.concatenate([img[:, :, :3], foreground[:, :, np.newaxis]], axis=-1)
                img[img[:, :, -1] == 0.] = bg_color
                img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB")
                input_image.append(image_transforms(img))
        for T_out_index in test_index:
            if DATA_TYPE == "RTMV":
                img_path = os.path.join(DATA_DIR, obj_name, '%05d.exr' % T_out_index)
                gt_im = cv2.imread(img_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
                img = cv2.cvtColor(gt_im, cv2.COLOR_BGR2RGB, gt_im)
                img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB")
                gt_image.append(image_transforms(img))
                # load pose
                pose_path = os.path.join(DATA_DIR, obj_name, '%05d.json' % T_out_index)
                with open(pose_path, "r") as f:
                    pose_dict = json.load(f)
                output_RT = np.array(pose_dict["camera_data"]["cam2world"]).T
                output_RT = np.linalg.inv(output_RT)[:3]
                pose_out.append(get_pose(np.concatenate([output_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
            else:
                if DATA_TYPE == "NeRF":
                    img_path = os.path.join(DATA_DIR, obj_name, test_info[T_out_index]["file_path"] + ".png")
                    pose = np.array(test_info[T_out_index]["transform_matrix"])
                    if CaPE_TYPE == "6DoF":
                        # blender to opencv
                        pose[1:3, :] *= -1
                        pose = np.linalg.inv(pose)
                        # scale radius to [1.5, 2.2]
                        pose[:3, 3] *= 1. / max_t * radius
                    elif CaPE_TYPE == "4DoF":
                        pose = np.linalg.inv(pose)
                    pose_out.append(torch.from_numpy(get_pose(pose)))
| elif DATA_TYPE == "Franka": | |
| img_path = None | |
| azimuth, polar = angles_out[T_out_index] | |
| if CaPE_TYPE == "4DoF": | |
| pose_out.append(torch.from_numpy([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.])) | |
| elif CaPE_TYPE == "6DoF": | |
| pose = look_at(origin, xyzs[T_out_index], up) | |
| neg_theta = np.deg2rad(polar) | |
| neg_i = np.deg2rad(azimuth) | |
| xyz = np.array([np.sin(neg_theta) * np.cos(neg_i), | |
| np.sin(-neg_theta + np.pi) * np.sin(neg_i), | |
| np.cos(neg_theta)]) * radius | |
| assert np.allclose(xyzs[T_out_index], xyz) | |
| pose = np.linalg.inv(pose) | |
| pose[2, :] *= -1 | |
| pose_out.append(torch.from_numpy(get_pose(pose))) | |
| elif DATA_TYPE == "MVDream" or DATA_TYPE == "Text2Img" or DATA_TYPE == "dust3r": | |
| img_path = None | |
| azimuth, polar = angles_out[T_out_index] | |
| if CaPE_TYPE == "4DoF": | |
| pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.])) | |
| elif CaPE_TYPE == "6DoF": | |
| pose = look_at(origin, xyzs[T_out_index], up) | |
| pose = np.linalg.inv(pose) | |
| pose[2, :] *= -1 | |
| pose_out.append(torch.from_numpy(get_pose(pose))) | |
| else: # GSO | |
| img_path = os.path.join(DATA_DIR, obj_name, T_out_DATA_TYPE, "model/%03d.png" % T_out_index) | |
| pose_path = os.path.join(DATA_DIR, obj_name, T_out_DATA_TYPE, "model/%03d.npy" % T_out_index) | |
| if T_out_DATA_TYPE == "render_mvs_25" or T_out_DATA_TYPE == "render_6dof_25": # blender coordinate | |
| pose_out.append(get_pose(np.concatenate([np.load(pose_path)[:3, :], np.array([[0, 0, 0, 1]])], axis=0))) | |
| else: # opencv coordinate | |
| pose = get_pose(np.concatenate([np.load(pose_path)[:3, :], np.array([[0, 0, 0, 1]])], axis=0)) | |
| pose[1:3, :] *= -1 # pose out 36 is in opencv coordinate, pose in 25 is in blender coordinate | |
| pose_out.append(torch.from_numpy(pose)) | |
| # load image | |
| if img_path is not None: # sometimes don't have GT target view image | |
| img = plt.imread(img_path) | |
| img[img[:, :, -1] == 0.] = bg_color | |
| img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB") | |
| gt_image.append(image_transforms(img)) | |
        # [B, T, C, H, W]
        input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
        if len(gt_image) > 0:
            gt_image = torch.stack(gt_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)

        # [B, T, 4] for 4DoF angles, [B, T, 4, 4] for 6DoF matrices
        pose_in = np.stack(pose_in)
        pose_out = np.stack(pose_out)

        if CaPE_TYPE == "6DoF":
            # the 6DoF pipeline also consumes the transposed inverse poses,
            # presumably so relative transforms can be formed without inverting
            # on the fly
            pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
            pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
            pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
            pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)

        pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
        pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)

        input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
        if len(gt_image) > 0:
            gt_image = einops.rearrange(gt_image, "b t c h w -> (b t) c h w")
        assert T_in == input_image.shape[0]
        assert T_in == pose_in.shape[1]
        assert T_out == pose_out.shape[1]
        # run inference; the two CaPE variants differ only in the pose payload
        if CaPE_TYPE == "6DoF":
            poses = [[pose_out, pose_out_inv], [pose_in, pose_in_inv]]
        elif CaPE_TYPE == "4DoF":
            poses = [pose_out, pose_in]
        with torch.autocast("cuda"):
            image = pipeline(input_imgs=input_image, prompt_imgs=input_image, poses=poses,
                             height=h, width=w, T_in=T_in, T_out=T_out,
                             guidance_scale=args.guidance_scale, num_inference_steps=50, generator=generator,
                             output_type="numpy").images
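        # with output_type="numpy" the pipeline returns float arrays in [0, 1],
        # which is why the saving code below multiplies by 255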
        # save results
        output_dir = os.path.join(OUTPUT_DIR, obj_name)
        os.makedirs(output_dir, exist_ok=True)
        # save input image for visualization
        imsave(os.path.join(output_dir, 'input.png'),
               ((np.concatenate(input_image.permute(0, 2, 3, 1).cpu().numpy(), 1) + 1) / 2 * 255).astype(np.uint8))
        # save output images
        if T_out >= 30:
            # save as N individual imgs
            for i in range(T_out):
                imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
            # make a gif; skip the first frame in append_images so it is not duplicated
            frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
            frame_one = frames[0]
            frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames[1:],
                           save_all=True, duration=50, loop=1)
        else:
            imsave(os.path.join(output_dir, '0.png'), (np.concatenate(image, 1) * 255).astype(np.uint8))

        # save gt for visualization
        if len(gt_image) > 0:
            imsave(os.path.join(output_dir, 'gt.png'),
                   ((np.concatenate(gt_image.permute(0, 2, 3, 1).cpu().numpy(), 1) + 1) / 2 * 255).astype(np.uint8))
if __name__ == "__main__":
    args = parse_args()
    main(args)