Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2021 The Deeplab2 Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Tests for panoptic_quality metrics.""" | |
| import collections | |
| from absl import logging | |
| import numpy as np | |
| import tensorflow as tf | |
| from deeplab2.evaluation import panoptic_quality | |
| from deeplab2.evaluation import test_utils | |
| # See the definition of the color names at: | |
| # https://en.wikipedia.org/wiki/Web_colors. | |
| _CLASS_COLOR_MAP = { | |
| (0, 0, 0): 0, | |
| (0, 0, 255): 1, # Person (blue). | |
| (255, 0, 0): 2, # Bear (red). | |
| (0, 255, 0): 3, # Tree (lime). | |
| (255, 0, 255): 4, # Bird (fuchsia). | |
| (0, 255, 255): 5, # Sky (aqua). | |
| (255, 255, 0): 6, # Cat (yellow). | |
| } | |
| def combine_maps(semantic_map, instance_map, label_divisor): | |
| combined_map = instance_map + semantic_map * label_divisor | |
| return tf.cast(combined_map, tf.int32) | |
| class PanopticQualityMetricTest(tf.test.TestCase): | |
| def test_streaming_metric_on_single_image(self): | |
| max_instances_per_category = 1000 | |
| instance_class_map = { | |
| 0: 0, | |
| 47: 1, | |
| 97: 1, | |
| 133: 1, | |
| 150: 1, | |
| 174: 1, | |
| 198: 2, | |
| 215: 1, | |
| 244: 1, | |
| 255: 1, | |
| } | |
| gt_instances, gt_classes = test_utils.panoptic_segmentation_with_class_map( | |
| 'team_gt_instance.png', instance_class_map) | |
| pred_classes = test_utils.read_segmentation_with_rgb_color_map( | |
| 'team_pred_class.png', _CLASS_COLOR_MAP) | |
| pred_instances = test_utils.read_test_image( | |
| 'team_pred_instance.png', image_format='L') | |
| pq_obj = panoptic_quality.PanopticQuality( | |
| num_classes=3, | |
| max_instances_per_category=max_instances_per_category, | |
| ignored_label=0, offset=256*256) | |
| y_true = combine_maps(gt_classes, gt_instances, max_instances_per_category) | |
| y_pred = combine_maps(pred_classes, pred_instances, | |
| max_instances_per_category) | |
| pq_obj.update_state(y_true, y_pred) | |
| result = pq_obj.result().numpy() | |
| self.assertAlmostEqual(result[0], 0.62156284, places=4) | |
| self.assertAlmostEqual(result[1], 0.64664984, places=4) | |
| self.assertAlmostEqual(result[2], 0.9666667, places=4) | |
| self.assertEqual(result[3], 4.) | |
| self.assertAlmostEqual(result[4], 0.5) | |
| self.assertEqual(result[5], 0.) | |
| def test_streaming_metric_on_multiple_images(self): | |
| num_classes = 7 | |
| bird_gt_instance_class_map = { | |
| 92: 5, | |
| 176: 3, | |
| 255: 4, | |
| } | |
| cat_gt_instance_class_map = { | |
| 0: 0, | |
| 255: 6, | |
| } | |
| team_gt_instance_class_map = { | |
| 0: 0, | |
| 47: 1, | |
| 97: 1, | |
| 133: 1, | |
| 150: 1, | |
| 174: 1, | |
| 198: 2, | |
| 215: 1, | |
| 244: 1, | |
| 255: 1, | |
| } | |
| max_instances_per_category = 256 | |
| test_image = collections.namedtuple( | |
| 'TestImage', | |
| ['gt_class_map', 'gt_path', 'pred_inst_path', 'pred_class_path']) | |
| test_images = [ | |
| test_image(bird_gt_instance_class_map, 'bird_gt.png', | |
| 'bird_pred_instance.png', 'bird_pred_class.png'), | |
| test_image(cat_gt_instance_class_map, 'cat_gt.png', | |
| 'cat_pred_instance.png', 'cat_pred_class.png'), | |
| test_image(team_gt_instance_class_map, 'team_gt_instance.png', | |
| 'team_pred_instance.png', 'team_pred_class.png'), | |
| ] | |
| gt_classes = [] | |
| gt_instances = [] | |
| pred_classes = [] | |
| pred_instances = [] | |
| for test_image in test_images: | |
| (image_gt_instances, | |
| image_gt_classes) = test_utils.panoptic_segmentation_with_class_map( | |
| test_image.gt_path, test_image.gt_class_map) | |
| gt_classes.append(image_gt_classes) | |
| gt_instances.append(image_gt_instances) | |
| pred_classes.append( | |
| test_utils.read_segmentation_with_rgb_color_map( | |
| test_image.pred_class_path, _CLASS_COLOR_MAP)) | |
| pred_instances.append( | |
| test_utils.read_test_image(test_image.pred_inst_path, | |
| image_format='L')) | |
| pq_obj = panoptic_quality.PanopticQuality( | |
| num_classes=num_classes, | |
| max_instances_per_category=max_instances_per_category, | |
| ignored_label=0, offset=256*256) | |
| for pred_class, pred_instance, gt_class, gt_instance in zip( | |
| pred_classes, pred_instances, gt_classes, gt_instances): | |
| y_true = combine_maps(gt_class, gt_instance, max_instances_per_category) | |
| y_pred = combine_maps(pred_class, pred_instance, | |
| max_instances_per_category) | |
| pq_obj.update_state(y_true, y_pred) | |
| result = pq_obj.result().numpy() | |
| self.assertAlmostEqual(result[0], 0.76855499, places=4) | |
| self.assertAlmostEqual(result[1], 0.7769174, places=4) | |
| self.assertAlmostEqual(result[2], 0.98888892, places=4) | |
| self.assertEqual(result[3], 2.) | |
| self.assertAlmostEqual(result[4], 1. / 6, places=4) | |
| self.assertEqual(result[5], 0.) | |
| def test_predicted_non_contiguous_ignore_label(self): | |
| max_instances_per_category = 256 | |
| pq_obj = panoptic_quality.PanopticQuality( | |
| num_classes=3, | |
| max_instances_per_category=max_instances_per_category, | |
| ignored_label=9, | |
| offset=256 * 256) | |
| gt_class = [ | |
| [0, 9, 9], | |
| [1, 2, 2], | |
| [1, 9, 9], | |
| ] | |
| gt_instance = [ | |
| [0, 2, 2], | |
| [1, 0, 0], | |
| [1, 0, 0], | |
| ] | |
| y_true = combine_maps( | |
| np.array(gt_class), np.array(gt_instance), max_instances_per_category) | |
| logging.info('y_true=\n%s', y_true) | |
| pred_class = [ | |
| [0, 0, 9], | |
| [1, 1, 1], | |
| [1, 9, 9], | |
| ] | |
| pred_instance = [ | |
| [0, 0, 0], | |
| [0, 1, 1], | |
| [0, 1, 1], | |
| ] | |
| y_pred = combine_maps( | |
| np.array(pred_class), np.array(pred_instance), | |
| max_instances_per_category) | |
| logging.info('y_pred=\n%s', y_pred) | |
| pq_obj.update_state(y_true, y_pred) | |
| result = pq_obj.result().numpy() | |
| # pq | |
| self.assertAlmostEqual(result[0], 2. / 9, places=4) | |
| # sq | |
| self.assertAlmostEqual(result[1], 1. / 3, places=4) | |
| # rq | |
| self.assertAlmostEqual(result[2], 2. / 9, places=4) | |
| # tp | |
| self.assertAlmostEqual(result[3], 1. / 3, places=4) | |
| # fn | |
| self.assertAlmostEqual(result[4], 2. / 3, places=4) | |
| # fp | |
| self.assertAlmostEqual(result[5], 2. / 3, places=4) | |
| if __name__ == '__main__': | |
| tf.test.main() | |