"""Multi-camera 3D detection demo built on mmdetection3d DETR3D / PETR models.

Builds a model selected by key, runs it on six camera images one at a time,
and renders all detections into a single BEV image.
"""
import os
import re
import cv2
import torch
import numpy as np
import torch.nn as nn
import mmengine
from mmengine.logging.history_buffer import HistoryBuffer
import numpy.core.multiarray as cam
import numpy.dtypes as dtypes

# Recent PyTorch releases default ``torch.load`` to ``weights_only=True``, which
# rejects the pickled Python objects stored in mmengine checkpoints. Every
# ``torch.load`` call in this process therefore defaults to full unpickling.
# SECURITY: full unpickling can execute arbitrary code -- only load checkpoints
# from sources you trust.
_orig_torch_load = torch.load


def _torch_load_full_pickle(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` defaulting to False.

    Accepts any positional arguments (e.g. a positional ``map_location``). A
    caller that passes ``weights_only`` explicitly keeps its own choice instead
    of hitting a duplicate-keyword TypeError.
    """
    kwargs.setdefault("weights_only", False)
    return _orig_torch_load(*args, **kwargs)


torch.load = _torch_load_full_pickle

# Allow-list for any loader that still runs with ``weights_only=True``.
# NOTE(review): ``getattr`` is a builtin that permits arbitrary attribute access
# during unpickling, which largely defeats the weights_only sandbox -- confirm
# it is really required.
torch.serialization.add_safe_globals([
    cam.scalar,
    HistoryBuffer,
    cam._reconstruct,
    np.ndarray,
    np.dtype,
    dtypes.Float64DType,
    dtypes.Int64DType,
    getattr,
])

from functools import lru_cache
from typing import List, Tuple, Dict, Union
from huggingface_hub import hf_hub_download
from mmengine.config import Config
from mmdet3d.apis import init_model, inference_mono_3d_detector
# NOTE(review): presumably provides MODEL_REGISTRY, DEVICE, CAM_ORDER,
# download_ckpt_from_hf, purge_project_registrations, index_from_name,
# infer_single and plot_bev_detections -- none are defined in this file.
from model.utils import *
from mmengine.registry import init_default_scope
from mmdet3d.utils import register_all_modules

register_all_modules(init_default_scope=False)

import importlib
import sys
from pathlib import Path


@lru_cache(maxsize=2)
def get_model(model_key: str):
    """Build (and cache) the mmdet3d model registered under ``model_key``.

    Args:
        model_key: Key into ``MODEL_REGISTRY``, e.g. 'regnetx4.0gf+detr3d' or
            'regnetx4.0gf+petr'. A key containing "petr" (case-insensitive)
            selects the PETR project; any other key selects DETR3D.

    Returns:
        The model returned by ``init_model`` on ``DEVICE``, switched to eval mode.

    Raises:
        ValueError: ``model_key`` is not present in ``MODEL_REGISTRY``.
        ModuleNotFoundError: The matching ``projects.*`` package cannot be imported.
        FileNotFoundError: The registered config file does not exist.

    NOTE(review): at most two models are cached. A cache hit returns without
    running the purge/import steps below, so registry state reflects the most
    recently *built* model -- confirm downstream code tolerates this.
    """
    if model_key not in MODEL_REGISTRY:
        raise ValueError(f"Unknown model '{model_key}'. Available: {list(MODEL_REGISTRY)}")

    # Decide the project once so the purge and the import always agree.
    is_petr = "petr" in model_key.lower()

    # Remove symbols registered by the *other* project to avoid duplicate keys.
    purge_project_registrations("projects.DETR3D" if is_petr else "projects.PETR")

    # Ensure the mmdetection3d repo root (which contains `projects/`) is importable.
    repo_root = Path(__file__).resolve().parents[1] / "mmdetection3d"
    if repo_root.is_dir() and str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))

    # Import the project so its registrations (e.g. ResizeCropFlipImage) exist.
    proj_name = "projects.PETR.petr" if is_petr else "projects.DETR3D.detr3d"
    try:
        importlib.import_module(proj_name)
    except ModuleNotFoundError as e:
        # Report which path was tried to make the failure actionable.
        raise ModuleNotFoundError(
            f"Could not import {proj_name}. Ensure the 'projects' package is on sys.path. "
            f"Tried adding: {repo_root}"
        ) from e

    config_path, (repo_id, hf_file) = MODEL_REGISTRY[model_key]
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Config not found: {config_path}")
    ckpt_path = download_ckpt_from_hf(repo_id, hf_file)

    # Load + sanitize cfg: don't let the config trigger its own pretrained-weight
    # initialisation, since weights come from the checkpoint.
    cfg = Config.fromfile(config_path)
    cfg.default_scope = "mmdet3d"
    if isinstance(cfg.model, dict):
        cfg.model.setdefault("pretrained", None)
        if isinstance(cfg.model.get("backbone", None), dict):
            cfg.model["backbone"].setdefault("init_cfg", None)

    # Ensure the registry scope is mmdet3d before building.
    init_default_scope("mmdet3d")

    # Build from cfg (not the path) so the tweaks above apply.
    model = init_model(cfg, ckpt_path, device=DEVICE)
    model.eval()

    # Some backbones expect a batch dim: promote a single 3-D tensor to a batch of one.
    if hasattr(model, "img_backbone") and hasattr(model.img_backbone, "forward"):
        orig_forward = model.img_backbone.forward

        def _ensure_batch(x, *args, **kwargs):
            if isinstance(x, torch.Tensor) and x.dim() == 3:
                x = x.unsqueeze(0)
            return orig_forward(x, *args, **kwargs)

        model.img_backbone.forward = _ensure_batch

    return model


def infer_images(img_paths: List[str], model: str) -> List[str]:
    """Run 3D detection on six camera images and render one combined BEV image.

    Args:
        img_paths: List of 6 images (front, front_left, front_right, back,
            back_left, back_right). They are re-ordered with ``index_from_name``
            and paired with ``CAM_ORDER`` by position.
        model: 'regnetx4.0gf+detr3d' or 'regnetx4.0gf+petr'.

    Returns:
        A list with the path to the rendered BEV image.

    Raises:
        ValueError: ``img_paths`` does not contain exactly 6 entries, or
            ``model`` is not a registered key (from ``get_model``).
    """
    # Validate before the potentially slow model build / checkpoint download.
    if len(img_paths) != 6:
        raise ValueError(f"Expected 6 images, got {len(img_paths)}")
    ordered = sorted(img_paths, key=index_from_name)

    model_inst = get_model(model)

    all_boxes, all_scores, all_labels = [], [], []
    for idx, img_path in enumerate(ordered):
        cam_key = CAM_ORDER[idx]
        res = infer_single(model_inst, img_path, cam_key)
        # Collect detections from every camera into one pool.
        all_boxes.append(res.pred_instances_3d.bboxes_3d.tensor)
        all_scores.append(res.pred_instances_3d.scores_3d)
        all_labels.append(res.pred_instances_3d.labels_3d)

    # Exactly six images were processed, so none of these lists is empty.
    boxes = torch.cat(all_boxes, dim=0)
    scores = torch.cat(all_scores, dim=0)
    labels = torch.cat(all_labels, dim=0)

    # NOTE(review): fixed output path -- concurrent calls overwrite each other.
    combined_path = os.path.join("/tmp", "combined_bev.png")

    # Per-model score cut-offs for what gets drawn.
    model_lc = model.lower()
    if "detr3d" in model_lc:
        score_thresh = 0.1
    elif "petr" in model_lc:
        score_thresh = 0.35
    else:
        score_thresh = 0.2  # fallback

    plot_bev_detections(boxes, scores, labels, save_path=combined_path, score_thresh=score_thresh)
    return [combined_path]


if __name__ == "__main__":
    example_imgs = [f"/path/to/cam_{i}.png" for i in range(6)]
    if all(os.path.isfile(p) for p in example_imgs):
        out = infer_images(example_imgs, model="regnetx4.0gf+petr")
        print("Saved:", out)
    else:
        print("Provide six images named cam_0..cam_5.png to run the demo.")