from typing import Optional

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch


def _rotated_anchor(cx, cy, length, width, yaw):
    """Return the anchor corner for a rectangle rotated about its centre.

    ``patches.Rectangle`` rotates around its ``xy`` anchor corner, so to draw
    a box centred on ``(cx, cy)`` the anchor must be the centre plus the
    rotated offset ``(-length / 2, -width / 2)``.

    Args:
        cx: X coordinate of the box centre.
        cy: Y coordinate of the box centre.
        length: Box extent along its local x axis.
        width: Box extent along its local y axis.
        yaw: Counter-clockwise rotation in radians.

    Returns:
        ``(anchor_x, anchor_y)`` to pass as the rectangle's ``xy``.
    """
    cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
    half_l, half_w = length / 2.0, width / 2.0
    return (
        cx - half_l * cos_yaw + half_w * sin_yaw,
        cy - half_l * sin_yaw - half_w * cos_yaw,
    )


def plot_bev_detections(
    boxes_3d: torch.Tensor,
    scores: torch.Tensor,
    labels: torch.Tensor,
    score_thresh: float = 0.1,
    save_path: Optional[str] = None,
) -> None:
    """Plot 3D detection boxes as rotated rectangles in a top-down (BEV) view.

    Detections whose score is below ``score_thresh`` are dropped, as are all
    detections with label index 1. The remaining boxes are drawn on a
    100 x 100 axes window centred on the origin, coloured per class, with a
    legend listing only the classes that were actually drawn.

    Args:
        boxes_3d: One row per detection; the first seven values of each row
            are read as ``x, y, z, dx, dy, dz, yaw`` and any further values
            are ignored. Assumes ``yaw`` is in radians, counter-clockwise in
            the x-y plane -- TODO confirm against the producer of the boxes.
        scores: One confidence score per detection.
        labels: One class index per detection, indexing the ten built-in
            class names.
        score_thresh: Minimum score (inclusive) for a detection to be drawn.
        save_path: If truthy, the figure is written to this path and closed;
            otherwise the figure is shown with ``plt.show()``.

    Raises:
        IndexError: If a kept label is not a valid index into the class list.
        ValueError: If a box row has fewer than seven values.
    """
    class_names = [
        'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
        'pedestrian', 'motorcycle', 'bicycle', 'traffic_cone', 'car',
    ]
    # NOTE(review): index 1 ('construction_vehicle') is never drawn. The
    # reason is not evident from this file -- confirm it is intentional.
    skipped_label = 1

    # 1) Create figure & axes
    fig, ax = plt.subplots(figsize=(12, 12))
    # Equal scaling on both axes; otherwise rotated boxes appear sheared.
    ax.set_aspect('equal')

    # 2) Draw ego vehicle as a 2 x 2 square centred on the origin
    ax.add_patch(patches.Rectangle(
        (-1, -1), 2, 2,
        linewidth=1, edgecolor='black', facecolor='gray',
        label='Ego Vehicle',
    ))

    # 3) Move everything to NumPy once (detach so tensors that require grad
    #    can be converted), then filter by score.
    boxes_np = boxes_3d.detach().cpu().numpy()
    scores_np = scores.detach().cpu().numpy()
    labels_np = labels.detach().cpu().numpy()
    keep = scores_np >= score_thresh
    boxes_np, labels_np = boxes_np[keep], labels_np[keep]

    # 4) Prepare a color for each class ('tab10' has 10 distinct colors)
    cmap = plt.get_cmap('tab10')
    colors = {i: cmap(i % 10) for i in range(len(class_names))}

    # 5) Draw each box
    seen_labels = set()
    for box, label in zip(boxes_np, labels_np):
        if not (label != skipped_label):
            continue

        x, y, _z, dx, dy, _dz, yaw, *_ = box
        cls_idx = int(label)
        cls_name = class_names[cls_idx]
        color = colors[cls_idx]

        # NOTE(review): 'car' boxes are drawn 20% longer than predicted;
        # kept from the original code -- confirm this is still wanted.
        if cls_name.lower() == 'car':
            dx *= 1.2

        # Rectangle rotates about its anchor corner, so place the anchor
        # where the corner lands after rotating the box about its centre.
        anchor = _rotated_anchor(x, y, dx, dy, yaw)
        ax.add_patch(patches.Rectangle(
            anchor, dx, dy,
            angle=np.degrees(yaw),
            linewidth=1.5, edgecolor=color, facecolor='none',
        ))

        # Remember the class so it appears in the legend exactly once.
        seen_labels.add(cls_idx)

    # 6) Legend only for seen classes; skip it entirely when nothing was
    #    drawn so no empty legend frame is rendered.
    legend_handles = [
        patches.Patch(color=colors[cls_idx], label=class_names[cls_idx])
        for cls_idx in sorted(seen_labels)
    ]
    if legend_handles:
        ax.legend(handles=legend_handles, loc='upper right')

    # 7) Axes limits and labels
    ax.set_xlim(-50, 50)
    ax.set_ylim(-50, 50)
    ax.set_xlabel('X (meters)')
    ax.set_ylabel('Y (meters)')
    ax.set_title('BEV Detections')

    # 8) Save or show
    if save_path:
        fig.savefig(save_path, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()