import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
import numpy as np
from typing import List, Tuple, Dict, Union, Optional
import os
import re

from huggingface_hub import hf_hub_download
import mmengine
from mmdet3d.apis import init_model, inference_mono_3d_detector

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Maps a model key to (config path, (Hugging Face repo id, checkpoint filename)).
MODEL_REGISTRY: Dict[str, Tuple[str, Tuple[str, str]]] = {
    "regnetx4.0gf+detr3d": (
        "model/DETR3D/detr3d_r101_gridmask.py",
        ("yaghi27/RegnetX4.0GF_DETR3D", "epoch_30.pth"),
    ),
    "regnetx4.0gf+petr": (
        "model/PETR/petr_vovnet_gridmask_p4_800x320.py",
        ("yaghi27/RegnetX4.0GF_PETR", "epoch_30.pth"),
    ),
}

CAM_ORDER = [
    "CAM_FRONT",
    "CAM_FRONT_LEFT",
    "CAM_FRONT_RIGHT",
    "CAM_BACK",
    "CAM_BACK_LEFT",
    "CAM_BACK_RIGHT",
]

# Identity matrix used as a placeholder for camera calibration entries.
DUMMY_4x4 = np.eye(4, dtype=np.float32)

# Digits that sit immediately before the final file extension (e.g. the "0" in "cam_0.png").
_TRAILING_INDEX_RE = re.compile(r'(\d+)(?=\.[^.]+$)')


def index_from_name(path: str) -> int:
    """Return the integer that directly precedes the extension in ``path``.

    Intended for filenames such as ``cam_0.png`` written by the backend.
    Only the basename is inspected, so digits in directory names are ignored.

    Args:
        path: File path or bare filename.

    Returns:
        The trailing integer of the file stem, or ``9999`` when the stem does
        not end in digits (so such files sort after every indexed one when
        this function is used as a sort key).
    """
    match = _TRAILING_INDEX_RE.search(os.path.basename(path))
    return int(match.group(1)) if match else 9999


def purge_project_registrations(project_prefix: str) -> None:
    """Drop registry entries whose class lives under ``project_prefix``.

    Removes classes registered by a project (e.g., 'projects.DETR3D' or
    'projects.PETR') from the mmdet / mmdet3d registries listed below, to
    avoid name collisions when swapping models.

    Args:
        project_prefix: Module-path prefix; an entry is removed when the
            ``__module__`` of the registered object starts with this string.
    """
    registries = []

    # Best-effort on purpose: if either package (or one of the names) cannot
    # be imported for any reason, that package's registries are skipped.
    try:
        from mmdet.registry import (
            MODELS, TASK_UTILS, HOOKS, DATASETS, TRANSFORMS, METRICS
        )
        registries += [MODELS, TASK_UTILS, HOOKS, DATASETS, TRANSFORMS, METRICS]
    except Exception:
        pass

    try:
        from mmdet3d.registry import (
            MODELS as M3_MODELS,
            TASK_UTILS as M3_TASK_UTILS,
            HOOKS as M3_HOOKS,
            DATASETS as M3_DATASETS,
            TRANSFORMS as M3_TRANSFORMS,
            METRICS as M3_METRICS
        )
        registries += [
            M3_MODELS, M3_TASK_UTILS, M3_HOOKS,
            M3_DATASETS, M3_TRANSFORMS, M3_METRICS
        ]
    except Exception:
        pass

    for reg in registries:
        # NOTE(review): relies on the registry's private ``_module_dict``
        # mapping; registries without a plain dict there are left untouched.
        module_dict = getattr(reg, "_module_dict", None)
        if not isinstance(module_dict, dict):
            continue
        # Iterate over a snapshot because entries are removed during the loop.
        for name, cls in list(module_dict.items()):
            # ``__module__`` can be None on some objects; treat that as "".
            owner = getattr(cls, "__module__", "") or ""
            if owner.startswith(project_prefix):
                module_dict.pop(name, None)


def infer_single(model, img_path: str, cam_key: str):
    """Run single-frame mono-3D inference with a provided camera key.

    A temporary annotation file ``<img_path>.pkl`` is written next to the
    image, handed to ``inference_mono_3d_detector`` and removed afterwards,
    whether or not inference succeeds.

    Args:
        model: Detector instance passed through to
            ``inference_mono_3d_detector``.
        img_path: Path of the input image.
        cam_key: Camera name used both as the key inside the annotation and
            as the ``cam_type`` argument.

    Returns:
        Whatever ``inference_mono_3d_detector`` returns.
    """
    data_info = {
        'images': {
            cam_key: {
                'img_path': img_path,
                # Placeholder (identity) calibration matrices.
                'cam2img': DUMMY_4x4,
                'lidar2cam': DUMMY_4x4
            }
        },
        'img_shape': (800, 450, 3),  # adjust if your inputs differ
        'scale_factor': 1.0
    }
    ann_file = img_path + '.pkl'
    try:
        # Dump inside the try so a partially written file is cleaned up too.
        mmengine.dump({'data_list': [data_info]}, ann_file)
        result = inference_mono_3d_detector(
            model, img_path, ann_file, cam_type=cam_key
        )
    finally:
        try:
            os.remove(ann_file)
        except OSError:
            # The file may never have been created; nothing to clean up.
            pass
    return result


def download_ckpt_from_hf(
    repo_id: str,
    filename: str,
    cache_dir: Optional[str] = None
) -> str:
    """Download a checkpoint from the Hugging Face Hub and return its path.

    Args:
        repo_id: Hub repository id.
        filename: File inside the repository to fetch.
        cache_dir: Cache directory. When falsy, the ``HF_HUB_CACHE``
            environment variable is used, then ``/tmp/hf_cache``.

    Returns:
        Local filesystem path reported by ``hf_hub_download``.

    Raises:
        FileNotFoundError: If the returned path is not an existing file.
    """
    # default to a writable cache location in Spaces/containers
    cache_dir = cache_dir or os.environ.get("HF_HUB_CACHE") or "/tmp/hf_cache"
    os.makedirs(cache_dir, exist_ok=True)

    ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision="main",
        cache_dir=cache_dir,
        token=os.environ.get("HF_TOKEN"),
    )
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(
            f"Failed to download checkpoint: {repo_id}/{filename}"
        )
    return ckpt_path


def plot_bev_detections(
    boxes_3d: torch.Tensor,
    scores: torch.Tensor,
    labels: torch.Tensor,
    score_thresh: float = 0.1,
    save_path: Optional[str] = None
):
    """Draw a bird's-eye-view plot of 3D detections around the ego vehicle.

    Each kept box is drawn as a rectangle of size ``dx`` x ``dy`` centred on
    ``(x, y)`` and rotated by ``yaw`` about that centre.

    Args:
        boxes_3d: Boxes, one per row; the first seven values of a row are
            read as ``x, y, z, dx, dy, dz, yaw`` (yaw is converted from
            radians to degrees for drawing). Extra columns are ignored.
        scores: Per-box scores, used only for thresholding.
        labels: Per-box integer class indices into the class list below.
            Indices outside that list are drawn under the name
            ``class_<idx>`` instead of raising.
        score_thresh: Boxes with ``score < score_thresh`` are dropped.
        save_path: When truthy, the figure is saved there and closed;
            otherwise it is shown with ``plt.show()``.
    """
    class_names = [
        'car', 'truck', 'construction_vehicle', 'bus', 'trailer',
        'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone',
    ]

    # 1) Create figure & axes
    fig, ax = plt.subplots(figsize=(12, 12))

    # 2) Draw ego vehicle at origin
    ego_patch = patches.Rectangle(
        (-1, -1), 2, 2,
        linewidth=1, edgecolor='black', facecolor='gray',
        label='Ego Vehicle'
    )
    ax.add_patch(ego_patch)

    # 3) Filter by score
    mask = scores >= score_thresh
    boxes, labels = boxes_3d[mask], labels[mask]

    # 4) One colour per class index
    cmap = plt.get_cmap('tab10')  # up to 10 distinct colors

    # 5) Draw each box
    seen = {}  # class index -> (display name, colour), for the legend
    for box, label in zip(boxes, labels):
        # detach() so tensors that still require grad can be converted.
        values = box.detach().cpu().numpy()
        x, y, _z, dx, dy, _dz, yaw = (float(v) for v in values[:7])

        cls_idx = int(label)
        if 0 <= cls_idx < len(class_names):
            cls_name = class_names[cls_idx]
        else:
            # Unknown index: keep plotting rather than fail or wrap around.
            cls_name = f'class_{cls_idx}'
        color = cmap(cls_idx % 10)

        # example of stretching length for 'car' if you still want it
        if cls_name == 'car':
            dx *= 1.2

        # Rectangle rotates about its anchor corner, so place that corner
        # where the lower-left corner lands after rotating about (x, y).
        cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
        corner_x = x - (dx / 2) * cos_yaw + (dy / 2) * sin_yaw
        corner_y = y - (dx / 2) * sin_yaw - (dy / 2) * cos_yaw

        ax.add_patch(patches.Rectangle(
            (corner_x, corner_y), dx, dy,
            angle=np.degrees(yaw),
            linewidth=1.5, edgecolor=color, facecolor='none'
        ))

        # remember we saw this label so we can add it to legend once
        seen[cls_idx] = (cls_name, color)

    # 6) Legend: ego vehicle plus only the classes that were drawn
    legend_handles = [ego_patch]
    legend_handles.extend(
        patches.Patch(color=seen[idx][1], label=seen[idx][0])
        for idx in sorted(seen)
    )
    ax.legend(handles=legend_handles, loc='upper right')

    # 7) Axes limits and labels
    ax.set_xlim(-50, 50)
    ax.set_ylim(-50, 50)
    # Equal scaling keeps rotated boxes rectangular.
    ax.set_aspect('equal', adjustable='box')
    ax.set_xlabel('X (meters)')
    ax.set_ylabel('Y (meters)')
    ax.set_title('BEV Detections')

    # 8) Save or show
    if save_path:
        fig.savefig(save_path, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()