Spaces:
Runtime error
Runtime error
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| import torch | |
| import numpy as np | |
def _rotated_anchor(x, y, dx, dy, yaw):
    """Return the Rectangle ``xy`` anchor that centres a yaw-rotated box on (x, y).

    ``matplotlib.patches.Rectangle`` rotates by ``angle`` about its ``xy``
    anchor (the unrotated lower-left corner), so using ``(x - dx/2, y - dy/2)``
    directly would swing rotated boxes away from their true centre.  After
    rotation the centre sits at ``xy + R(yaw) @ (dx/2, dy/2)``; solving for
    ``xy`` gives the expression below.
    """
    cos_y, sin_y = np.cos(yaw), np.sin(yaw)
    half_dx, half_dy = dx / 2.0, dy / 2.0
    return (
        x - (half_dx * cos_y - half_dy * sin_y),
        y - (half_dx * sin_y + half_dy * cos_y),
    )


def plot_bev_detections(
    boxes_3d: torch.Tensor,
    scores: torch.Tensor,
    labels: torch.Tensor,
    score_thresh: float = 0.1,
    save_path: str = None,
    skip_labels: tuple = (1,),
):
    """Draw 3D detections as rotated rectangles in a bird's-eye-view (BEV) plot.

    Args:
        boxes_3d: Per-detection box rows; each row is unpacked as
            ``(x, y, z, dx, dy, dz, yaw, *extra)``. Only x, y, dx, dy and yaw
            are drawn. Yaw is converted with ``np.degrees``, so it is expected
            in radians.
        scores: Per-detection confidence, compared against ``score_thresh``.
        labels: Per-detection class index into the class list below.
        score_thresh: Detections with ``score < score_thresh`` are dropped.
        save_path: If given, the figure is saved there and closed; otherwise
            ``plt.show()`` is called.
        skip_labels: Class indices that are never drawn. The default ``(1,)``
            keeps the historical behaviour of hiding ``construction_vehicle``;
            pass ``()`` or ``None`` to draw every class.
    """
    class_names = [
        'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', 'pedestrian',
        'motorcycle', 'bicycle', 'traffic_cone', 'car'
    ]
    # 1) Create figure & axes
    fig, ax = plt.subplots(figsize=(12, 12))

    # 2) Draw ego vehicle as a 2x2 square at the origin
    ax.add_patch(patches.Rectangle(
        (-1, -1), 2, 2,
        linewidth=1,
        edgecolor='black',
        facecolor='gray',
        label='Ego Vehicle'
    ))

    # 3) Filter by score
    mask = scores >= score_thresh
    boxes, labels = boxes_3d[mask], labels[mask]

    # 4) Prepare a color for each class (tab10 has exactly 10 distinct colors)
    cmap = plt.get_cmap('tab10')
    num_classes = len(class_names)
    colors = {i: cmap(i % 10) for i in range(num_classes)}
    skipped = {int(s) for s in skip_labels} if skip_labels else set()

    # 5) Draw each box
    seen_labels = set()
    for box, label in zip(boxes, labels):
        cls_idx = int(label)
        if cls_idx in skipped:
            continue
        # detach() so .numpy() also works on tensors that require grad
        x, y, z, dx, dy, dz, yaw, *_ = box.detach().cpu().numpy()

        # Guard against labels outside the known class list (including
        # negatives, which would otherwise silently index from the end).
        if 0 <= cls_idx < num_classes:
            cls_name = class_names[cls_idx]
            color = colors[cls_idx]
        else:
            cls_name = f'class_{cls_idx}'
            color = cmap(cls_idx % 10)

        # NOTE(review): cars are drawn 20% longer than predicted; this is a
        # purely visual tweak carried over from the original code.
        if cls_name.lower() == 'car':
            dx *= 1.2

        rect = patches.Rectangle(
            _rotated_anchor(x, y, dx, dy, yaw),
            dx, dy,
            angle=np.degrees(yaw),
            linewidth=1.5,
            edgecolor=color,
            facecolor='none'
        )
        ax.add_patch(rect)
        # remember (index, name, color) so each class appears once in the legend
        seen_labels.add((cls_idx, cls_name, color))

    # 6) Legend only for seen classes (skip entirely when nothing was drawn)
    legend_handles = [
        patches.Patch(color=color, label=name)
        for _, name, color in sorted(seen_labels, key=lambda t: t[0])
    ]
    if legend_handles:
        ax.legend(handles=legend_handles, loc='upper right')

    # 7) Axes limits and labels
    ax.set_xlim(-50, 50)
    ax.set_ylim(-50, 50)
    ax.set_xlabel('X (meters)')
    ax.set_ylabel('Y (meters)')
    ax.set_title('BEV Detections')

    # 8) Save or show
    if save_path:
        fig.savefig(save_path, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()