Spaces:
Sleeping
Sleeping
| from collections import defaultdict | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from matplotlib import cm | |
| import torch | |
def draw_panoptic_segmentation(model, segmentation, segments_info,
                               output_path='final_mask.png', show_legend=False):
    """Render a panoptic segmentation map to an image file and return its path.

    Args:
        model: Object exposing ``config.id2label``; used to turn each
            segment's ``label_id`` into a human-readable class name.
        segmentation: Torch tensor of per-pixel segment ids. Presumably a 2-D
            (height, width) map from a panoptic post-processor -- TODO confirm
            against callers.
        segments_info: Iterable of dicts with at least ``'id'`` (the value that
            segment takes in ``segmentation``) and ``'label_id'``.
        output_path: File the figure is saved to. Defaults to
            ``'final_mask.png'``, the path that used to be hard-coded.
        show_legend: If True, draw a legend mapping each segment's color to
            ``"<label>-<instance index>"``. Defaults to False, which matches
            the previous output (the legend call was commented out).

    Returns:
        ``output_path``.
    """
    seg = segmentation.cpu().numpy()

    # Use a single colormap + normalisation for both the image and the legend
    # swatches so a segment id maps to the same color in both. The previous
    # code indexed a viridis resampled to max(id) entries by integer id, which
    # did not match imshow's min/max scaling, and let the largest id fall into
    # the "over" slot (same color as id max-1). It also relied on
    # ``matplotlib.cm.get_cmap``, removed in matplotlib 3.9.
    cmap = plt.get_cmap('viridis')
    norm = plt.Normalize(vmin=seg.min(), vmax=seg.max())

    fig, ax = plt.subplots()
    try:
        ax.imshow(seg, cmap=cmap, norm=norm)

        # Number repeated classes: "person-0", "person-1", ...
        instances_counter = defaultdict(int)
        handles = []
        for segment in segments_info:
            segment_id = segment['id']
            segment_label_id = segment['label_id']
            segment_label = model.config.id2label[segment_label_id]
            label = f"{segment_label}-{instances_counter[segment_label_id]}"
            instances_counter[segment_label_id] += 1
            color = cmap(norm(segment_id))
            handles.append(mpatches.Patch(color=color, label=label))

        if show_legend and handles:
            ax.legend(handles=handles)
        fig.savefig(output_path)
    finally:
        # Release the figure; pyplot otherwise keeps every figure alive, which
        # leaks memory when this is called repeatedly (e.g. per request).
        plt.close(fig)
    return output_path