# Source: model/utils.py (2.5 kB) -- commit 3d641e6 by yaghi27: "Update model/utils.py"
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
import numpy as np
def plot_bev_detections(
    boxes_3d: torch.Tensor,
    scores: torch.Tensor,
    labels: torch.Tensor,
    score_thresh: float = 0.1,
    save_path: str = None
):
    """Draw a bird's-eye-view (BEV) plot of 3D detection boxes.

    Each kept detection is drawn as a rectangle outline in the X/Y plane,
    rotated by its yaw about the box centre and coloured by class. A gray
    2 x 2 square at the origin marks the ego vehicle.

    Args:
        boxes_3d: Tensor with one row per detection; the first seven values
            of each row are read as ``x, y, z, dx, dy, dz, yaw`` and any
            further columns are ignored. ``yaw`` is converted with
            ``np.degrees``, i.e. it is treated as radians.
        scores: Per-detection scores; detections with
            ``score < score_thresh`` are dropped.
        labels: Per-detection class indices into the 10-entry class list
            defined below (valid range 0..9).
        score_thresh: Minimum score a detection needs to be drawn.
        save_path: If truthy, the figure is written to this path and closed;
            otherwise it is displayed with ``plt.show()``.

    Returns:
        None.
    """
    class_names = [
        'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', 'pedestrian',
        'motorcycle', 'bicycle', 'traffic_cone', 'car'
    ]
    # NOTE(review): the original code skips class index 1
    # ('construction_vehicle'); kept as-is -- confirm this is intentional.
    skipped_class = 1

    # 1) Create figure & axes
    fig, ax = plt.subplots(figsize=(12, 12))

    # 2) Draw ego vehicle at origin
    ego_patch = patches.Rectangle(
        (-1, -1), 2, 2,
        linewidth=1,
        edgecolor='black',
        facecolor='gray',
        label='Ego Vehicle'
    )
    ax.add_patch(ego_patch)

    # 3) Filter by score, then move everything to host memory once.
    #    detach() keeps this working for tensors that still require grad.
    keep = scores >= score_thresh
    kept_boxes = boxes_3d[keep].detach().cpu().numpy()
    kept_labels = labels[keep].detach().cpu().numpy()

    # 4) Prepare a color for each class
    cmap = plt.get_cmap('tab10')  # up to 10 distinct colors
    colors = {i: cmap(i % 10) for i in range(len(class_names))}

    # 5) Draw each box
    seen_labels = set()
    for box, label in zip(kept_boxes, kept_labels):
        cls_idx = int(label)
        if cls_idx == skipped_class:
            continue

        x, y, _z, dx, dy, _dz, yaw = box[:7]
        cls_name = class_names[cls_idx]
        color = colors[cls_idx]

        # example of stretching length for 'car' if you still want it
        if cls_name == 'car':
            dx *= 1.2

        # Rectangle rotates about its anchor corner, not its centre, so
        # place the anchor where the lower-left corner ends up after
        # rotating the box by `yaw` around its centre (x, y).
        cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
        corner_x = x - (dx / 2) * cos_yaw + (dy / 2) * sin_yaw
        corner_y = y - (dx / 2) * sin_yaw - (dy / 2) * cos_yaw

        ax.add_patch(patches.Rectangle(
            (corner_x, corner_y),
            dx, dy,
            angle=np.degrees(yaw),
            linewidth=1.5,
            edgecolor=color,
            facecolor='none'
        ))

        # remember we saw this label so we can add it to legend once
        seen_labels.add(cls_idx)

    # 6) Legend: ego vehicle plus only the classes that were drawn
    legend_handles = [ego_patch]
    legend_handles.extend(
        patches.Patch(color=colors[cls_idx], label=class_names[cls_idx])
        for cls_idx in sorted(seen_labels)
    )
    ax.legend(handles=legend_handles, loc='upper right')

    # 7) Axes limits and labels
    ax.set_xlim(-50, 50)
    ax.set_ylim(-50, 50)
    # Equal scaling on both axes so rotated boxes are not drawn skewed.
    ax.set_aspect('equal')
    ax.set_xlabel('X (meters)')
    ax.set_ylabel('Y (meters)')
    ax.set_title('BEV Detections')

    # 8) Save or show
    if save_path:
        fig.savefig(save_path, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()