| """ | |
| inference.py | |
| ------------ | |
| Provides functionality to run the OPDMulti model on an input image, independent of dataset and ground truth, and | |
| visualize the output. Large portions of the code originate from get_prediction.py, rgbd_to_pcd_vis.py, | |
| evaluate_on_log.py, and other related files. The primary goal was to create a more standalone script which could be | |
| converted more easily into a public demo, thus the goal was to sever most dependencies on existing ground truth or | |
| datasets. | |
| Example usage: | |
| python inference.py \ | |
| --rgb path/to/59-4860.png \ | |
| --depth path/to/59-4860_d.png \ | |
| --model path/to/model.pth \ | |
| --output path/to/output_dir | |
| """ | |
import argparse
import logging
import os
import time
from typing import Any

import imageio
import open3d as o3d
import numpy as np
import torch
import torch.nn as nn
from detectron2 import engine, evaluation
from detectron2.modeling import build_model
from detectron2.config import get_cfg, CfgNode
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.structures import instances
from detectron2.utils import comm
from detectron2.utils.logger import setup_logger

from mask2former import (
    add_maskformer2_config,
    add_motionnet_config,
)
from utilities import prediction_to_json
from visualization import (
    draw_line,
    generate_rotation_visualization,
    generate_translation_visualization,
    batch_trim,
)

# import based on torch version. Required for model loading. Code is taken from fvcore.common.checkpoint, in order to
# replicate model loading without the overhead of setting up an OPDTrainer
TORCH_VERSION: tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 11):
    from torch.ao import quantization
    from torch.ao.quantization import FakeQuantizeBase, ObserverBase
elif (
    TORCH_VERSION >= (1, 8)
    and hasattr(torch.quantization, "FakeQuantizeBase")
    and hasattr(torch.quantization, "ObserverBase")
):
    from torch import quantization
    from torch.quantization import FakeQuantizeBase, ObserverBase

# TODO: find a global place for this instead of in many places in code
TYPE_CLASSIFICATION = {
    0: "rotation",
    1: "translation",
}

ARROW_COLOR = [0, 1, 0]  # green


def get_parser() -> argparse.ArgumentParser:
    """
    Specify command-line arguments.

    The primary inputs to the script should be the image paths (RGBD) and camera intrinsics. Other arguments are
    provided to facilitate script testing and model changes. Run the file with -h/--help to see all arguments.

    :return: parser for extracting command-line arguments
    """
    parser = argparse.ArgumentParser(description="Inference for OPDMulti")

    # The main arguments which should be specified by the user
    parser.add_argument(
        "--rgb",
        dest="rgb_image",
        metavar="FILE",
        help="path to RGB image file on which to run model",
    )
    parser.add_argument(
        "--depth",
        dest="depth_image",
        metavar="FILE",
        help="path to depth image file on which to run model",
    )
    parser.add_argument(  # FIXME: might make more sense to make this a path
        "-i",
        "--intrinsics",
        nargs=9,
        type=float,
        default=[
            214.85935872395834,
            0.0,
            0.0,
            0.0,
            214.85935872395834,
            0.0,
            125.90160319010417,
            95.13726399739583,
            1.0,
        ],
        dest="intrinsics",
        help="camera intrinsics matrix, given as 9 values in column-major order: fx 0 0 0 fy 0 cx cy 1",
    )
    # optional parameters for user to specify
    parser.add_argument(
        "-n",
        "--num-samples",
        default=10,
        type=int,
        dest="num_samples",
        metavar="NUM",
        help="number of sample states to generate in visualization",
    )
    parser.add_argument(
        "--crop",
        action="store_true",
        dest="crop",
        help="crop whitespace out of images for visualization",
    )
    # local script development arguments
    parser.add_argument(
        "-m",
        "--model",
        default="path/to/model/file",  # FIXME: set a good default path
        dest="model",
        metavar="FILE",
        help="path to model file to run",
    )
    parser.add_argument(
        "-c",
        "--config",
        default="configs/coco/instance-segmentation/swin/opd_v1_real.yaml",
        metavar="FILE",
        dest="config_file",
        help="path to config file",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="output",  # FIXME: set a good default path
        dest="output",
        help="path to output directory in which to save results",
    )
    parser.add_argument(
        "--num-processes",
        default=1,
        type=int,
        dest="num_processes",
        help="number of processes per machine. When using GPUs, this should be the number of GPUs.",
    )
    parser.add_argument(
        "-s",
        "--score-threshold",
        default=0.8,
        type=float,
        dest="score_threshold",
        help="threshold between 0.0 and 1.0 by which to filter out bad predictions",
    )
    parser.add_argument(
        "--input-format",
        default="RGB",
        dest="input_format",
        help="input format of image. Must be one of RGB, RGBD, or depth",
    )
    parser.add_argument(
        "--cpu",
        action="store_true",
        help="flag to require code to use CPU only",
    )
    return parser


def setup_cfg(args: argparse.Namespace) -> CfgNode:
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    # add model configurations
    add_deeplab_config(cfg)
    add_maskformer2_config(cfg)
    add_motionnet_config(cfg)
    cfg.merge_from_file(args.config_file)

    # set additional config parameters
    cfg.MODEL.WEIGHTS = args.model
    cfg.OBJ_DETECT = False  # TODO: figure out if this is needed, and parameterize it
    cfg.MODEL.MOTIONNET.VOTING = "none"
    # Output directory
    cfg.OUTPUT_DIR = args.output
    cfg.MODEL.DEVICE = "cpu" if args.cpu else "cuda"
    cfg.MODEL.MODELATTRPATH = None

    # Input format
    cfg.INPUT.FORMAT = args.input_format
    if args.input_format == "RGB":
        cfg.MODEL.PIXEL_MEAN = cfg.MODEL.PIXEL_MEAN[0:3]
        cfg.MODEL.PIXEL_STD = cfg.MODEL.PIXEL_STD[0:3]
    elif args.input_format == "depth":
        cfg.MODEL.PIXEL_MEAN = cfg.MODEL.PIXEL_MEAN[3:4]
        cfg.MODEL.PIXEL_STD = cfg.MODEL.PIXEL_STD[3:4]
    elif args.input_format == "RGBD":
        pass
    else:
        raise ValueError("Invalid input format")
    cfg.freeze()
    engine.default_setup(cfg, args)
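    # NOTE (assumption based on detectron2's default_setup behavior): the call above also creates cfg.OUTPUT_DIR if
    # needed, seeds the RNGs, and configures the base detectron2 logger; only the project-specific logger is added
    # below.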
    # Setup logger for "mask_former" module
    setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="opdformer")
    return cfg


def format_input(rgb_path: str) -> list[dict[str, Any]]:
    """
    Read and format input image into detectron2 form so that it can be passed to the model.

    :param rgb_path: path to RGB image file
    :return: list of dictionaries per image, where each dictionary is of the form
        {
            "file_name": path to RGB image,
            "image": torch.Tensor of dimensions [channel, height, width] representing the image
        }
    """
    image = imageio.imread(rgb_path).astype(np.float32)
    image_tensor = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))  # dim: [channel, height, width]
    return [{"file_name": rgb_path, "image": image_tensor}]
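

# NOTE: only the RGB image is passed to the model here; the depth image supplied on the command line is used in
# main() solely to back-project the prediction into a point cloud for visualization. Running with
# --input-format RGBD or depth would therefore require extending format_input() to load and stack the depth channel.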


def load_model(model: nn.Module, checkpoint: Any) -> None:
    """
    Load weights from a checkpoint.

    The majority of the function definition is taken from the DetectionCheckpointer implementation provided in
    detectron2. While not all of this code is necessarily needed for model loading, it was ported with the intention
    of keeping the implementation and output as close to the original as possible, and reusing the checkpoint class
    here in isolation was determined to be infeasible.

    :param model: model for which to load weights
    :param checkpoint: checkpoint containing the weights
    """

    def _strip_prefix_if_present(state_dict: dict[str, Any], prefix: str) -> None:
        """If prefix is found on all keys in state dict, remove prefix."""
        keys = sorted(state_dict.keys())
        if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
            return
        for key in keys:
            newkey = key[len(prefix) :]
            state_dict[newkey] = state_dict.pop(key)
    checkpoint_state_dict = checkpoint.pop("model")
    # convert from numpy to tensor
    for k, v in checkpoint_state_dict.items():
        if not isinstance(v, np.ndarray) and not isinstance(v, torch.Tensor):
            raise ValueError("Unsupported type found in checkpoint! {}: {}".format(k, type(v)))
        if not isinstance(v, torch.Tensor):
            checkpoint_state_dict[k] = torch.from_numpy(v)

    # if the state_dict comes from a model that was wrapped in a
    # DataParallel or DistributedDataParallel during serialization,
    # remove the "module" prefix before performing the matching.
    _strip_prefix_if_present(checkpoint_state_dict, "module.")

    # workaround https://github.com/pytorch/pytorch/issues/24139
    model_state_dict = model.state_dict()
    incorrect_shapes = []
    for k in list(checkpoint_state_dict.keys()):  # state dict is modified in loop, so list op is necessary
        if k in model_state_dict:
            model_param = model_state_dict[k]
            # Allow mismatch for uninitialized parameters
            if TORCH_VERSION >= (1, 8) and isinstance(model_param, nn.parameter.UninitializedParameter):
                continue
            shape_model = tuple(model_param.shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                has_observer_base_classes = (
                    TORCH_VERSION >= (1, 8)
                    and hasattr(quantization, "ObserverBase")
                    and hasattr(quantization, "FakeQuantizeBase")
                )
                if has_observer_base_classes:
                    # Handle the special case of quantization per channel observers,
                    # where buffer shape mismatches are expected.
                    def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
                        # foo.bar.param_or_buffer_name -> [foo, bar]
                        key_parts = key.split(".")[:-1]
                        cur_module = model
                        for key_part in key_parts:
                            cur_module = getattr(cur_module, key_part)
                        return cur_module

                    cls_to_skip = (
                        ObserverBase,
                        FakeQuantizeBase,
                    )
                    target_module = _get_module_for_key(model, k)
                    if isinstance(target_module, cls_to_skip):
                        # Do not remove modules with expected shape mismatches from the state_dict loading;
                        # they have special logic in _load_from_state_dict to handle the mismatches.
                        continue
                incorrect_shapes.append((k, shape_checkpoint, shape_model))
                checkpoint_state_dict.pop(k)
    model.load_state_dict(checkpoint_state_dict, strict=False)
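

# Hedged usage sketch for load_model(), mirroring what main() does below ("model.pth" is a placeholder path):
#     weights = torch.load("model.pth", map_location=torch.device("cpu"))
#     if "model" not in weights:
#         weights = {"model": weights}  # wrap a bare state_dict in the checkpoint format load_model expects
#     load_model(model, weights)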


def predict(model: nn.Module, inp: list[dict[str, Any]]) -> list[dict[str, instances.Instances]]:
    """
    Compute model predictions.

    :param model: model to run on input
    :param inp: input, as a list of dictionaries of the form
        {
            "file_name": path to image,
            "image": float32 torch.Tensor of dimensions [channel, height, width] as RGB/RGBD/depth image
        }
    :return: list of detected instances and predicted openable parameters
    """
    with torch.no_grad(), evaluation.inference_context(model):
        out = model(inp)
    return out
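

# NOTE: each returned entry holds an "instances" field; the per-instance attributes consumed in main() below are
# scores, pred_masks, and the openable-part parameters morigin (origin), maxis (axis direction), mtype
# (rotation/translation), mstate (current state), and mstatemax (maximum state).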


def main(
    cfg: CfgNode,
    rgb_image: str,
    depth_image: str,
    intrinsics: list[float],
    num_samples: int,
    crop: bool,
    score_threshold: float,
) -> None:
    """
    Main inference method.

    :param cfg: configuration object
    :param rgb_image: local path to RGB image
    :param depth_image: local path to depth image
    :param intrinsics: camera intrinsics matrix as a list of 9 values
    :param num_samples: number of sample visualization states to generate
    :param crop: if True, images will be cropped to remove whitespace before visualization
    :param score_threshold: float between 0 and 1 representing threshold at which to filter instances based on score
    """
    logger = logging.getLogger("detectron2")

    # setup data
    logger.info("Loading image.")
    inp = format_input(rgb_image)

    # setup model
    logger.info("Loading model.")
    model = build_model(cfg)
    weights = torch.load(cfg.MODEL.WEIGHTS, map_location=torch.device("cpu"))
    if "model" not in weights:
        weights = {"model": weights}
    load_model(model, weights)

    # run model on data
    logger.info("Running model.")
    prediction = predict(model, inp)[0]  # index 0 since there is only one image
    pred_instances = prediction["instances"]

    # log results
    image_id = os.path.splitext(os.path.basename(rgb_image))[0]
    pred_dict = {"image_id": image_id}
    instances_cpu = pred_instances.to(torch.device("cpu"))
    pred_dict["instances"] = prediction_to_json(instances_cpu, image_id)
    torch.save(pred_dict, os.path.join(cfg.OUTPUT_DIR, f"{image_id}_prediction.pth"))

    # select best prediction to visualize
    score_ranking = np.argsort([-1 * pred_instances[i].scores.item() for i in range(len(pred_instances))])
    score_ranking = [idx for idx in score_ranking if pred_instances[int(idx)].scores.item() > score_threshold]
    if len(score_ranking) == 0:
        logger.warning("The model did not predict any moving parts above the score threshold.")
        return
    for idx in score_ranking:  # iterate through all best predictions, by score threshold
        pred = pred_instances[int(idx)]  # take highest predicted one
        logger.info("Rendering prediction for instance %d", int(idx))
        output_dir = os.path.join(cfg.OUTPUT_DIR, str(idx))
        os.makedirs(output_dir, exist_ok=True)

        # extract predicted values for visualization
        mask = np.squeeze(pred.pred_masks.cpu().numpy())  # dim: [height, width]
        origin = pred.morigin.cpu().numpy().flatten()  # dim: [3, ]
        axis_vector = pred.maxis.cpu().numpy().flatten()  # dim: [3, ]
        pred_type = TYPE_CLASSIFICATION.get(pred.mtype.item())
        range_min = 0 - pred.mstate.cpu().numpy()
        range_max = pred.mstatemax.cpu().numpy() - pred.mstate.cpu().numpy()

        # process visualization
        color = o3d.io.read_image(rgb_image)
        depth = o3d.io.read_image(depth_image)
        rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(color, depth, convert_rgb_to_intensity=False)
        color_np = np.asarray(color)
        height, width = color_np.shape[:2]

        # generate intrinsics
        intrinsic_matrix = np.reshape(intrinsics, (3, 3), order="F")
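        # NOTE: order="F" reads the 9 intrinsics values column-major, so the default argument
        # [fx, 0, 0, 0, fy, 0, cx, cy, 1] becomes the usual pinhole matrix [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].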
        intrinsic_obj = o3d.camera.PinholeCameraIntrinsic(
            width,
            height,
            intrinsic_matrix[0, 0],
            intrinsic_matrix[1, 1],
            intrinsic_matrix[0, 2],
            intrinsic_matrix[1, 2],
        )

        # Convert the RGBD image to a point cloud
        pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic_obj)

        # Create a LineSet to visualize the direction vector
        axis_arrow = draw_line(origin, axis_vector + origin)
        axis_arrow.paint_uniform_color(ARROW_COLOR)

        # if USE_GT:
        #     anno_path = f"/localhome/atw7/projects/opdmulti/data/data_demo_dev/59-4860.json"
        #     part_id = 32
        #     # get annotation for the frame
        #     import json
        #     with open(anno_path, "r") as f:
        #         anno = json.load(f)
        #     articulations = anno["articulation"]
        #     for articulation in articulations:
        #         if articulation["partId"] == part_id:
        #             range_min = articulation["rangeMin"] - articulation["state"]
        #             range_max = articulation["rangeMax"] - articulation["state"]
        #             break
        if pred_type == "rotation":
            generate_rotation_visualization(
                pcd,
                axis_arrow,
                mask,
                axis_vector,
                origin,
                range_min,
                range_max,
                num_samples,
                output_dir,
            )
        elif pred_type == "translation":
            generate_translation_visualization(
                pcd,
                axis_arrow,
                mask,
                axis_vector,
                range_min,
                range_max,
                num_samples,
                output_dir,
            )
        else:
            raise ValueError(f"Invalid motion prediction type: {pred_type}")

        if pred_type:
            if crop:  # crop images to remove shared extraneous whitespace
                output_dir_cropped = f"{output_dir}_cropped"
                if not os.path.isdir(output_dir_cropped):
                    os.makedirs(output_dir_cropped)
                batch_trim(output_dir, output_dir_cropped, identical=True)
                # create_gif(output_dir_cropped, num_samples)
            else:  # leave original dimensions of image as-is
                # create_gif(output_dir, num_samples)
                pass


if __name__ == "__main__":
    # parse arguments
    start_time = time.time()
    args = get_parser().parse_args()
    cfg = setup_cfg(args)

    # run main code
    engine.launch(
        main,
        args.num_processes,
        args=(
            cfg,
            args.rgb_image,
            args.depth_image,
            args.intrinsics,
            args.num_samples,
            args.crop,
            args.score_threshold,
        ),
    )
    end_time = time.time()
    print(f"Inference time: {end_time - start_time:.2f} seconds")