Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2021 The Deeplab2 Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Tests for build_cityscapes_data.""" | |
import os

from absl import flags
from absl.testing import flagsaver
import numpy as np
from PIL import Image
import tensorflow as tf

from deeplab2.data import build_cityscapes_data
# Global absl flag registry; the flags read by build_cityscapes_data live here.
FLAGS = flags.FLAGS

# Directory holding the dummy Cityscapes-style fixtures used by these tests.
_TEST_DATA_DIR = 'deeplab2/data/testdata'
# Key under which the dummy sample appears in the dict built by _read_segments.
_TEST_FILE_PREFIX = 'dummy_000000_000000'
class BuildCityscapesDataTest(tf.test.TestCase):
  """Tests build_cityscapes_data helpers against the dummy fixtures."""

  def test_read_segments(self):
    """Checks that _read_segments finds the dummy sample and its 10 segments."""
    segments_dict = build_cityscapes_data._read_segments(
        _TEST_DATA_DIR, dataset_split='dummy')
    self.assertIn(_TEST_FILE_PREFIX, segments_dict)
    # Each entry is an (annotation_file_name, segments) pair.
    _, segments = segments_dict[_TEST_FILE_PREFIX]
    self.assertLen(segments, 10)

  # treat_crowd_as_ignore=False exercises a more complicated setting.
  # flagsaver restores the flag's previous value when the test finishes, so
  # the override cannot leak into other tests run in the same process.
  @flagsaver.flagsaver(treat_crowd_as_ignore=False)
  def test_generate_panoptic_label(self):
    """Checks the generated panoptic label against the golden PNG."""
    cityscapes_root = _TEST_DATA_DIR
    segments_dict = build_cityscapes_data._read_segments(
        cityscapes_root, dataset_split='dummy')
    annotation_file_name, segments = segments_dict[_TEST_FILE_PREFIX]
    panoptic_annotation_file = build_cityscapes_data._get_panoptic_annotation(
        cityscapes_root, dataset_split='dummy',
        annotation_file_name=annotation_file_name)
    panoptic_label = build_cityscapes_data._generate_panoptic_label(
        panoptic_annotation_file, segments)

    # Check panoptic label matches golden file.
    golden_file_path = os.path.join(_TEST_DATA_DIR, 'dummy_gt_for_vps.png')
    with tf.io.gfile.GFile(golden_file_path, 'rb') as f:
      golden_image = Image.open(f)
      # Decode while the file is still open: Image.open reads pixel data lazily.
      # The PNG file is encoded by:
      #   color = [segmentId % 256, segmentId // 256 % 256,
      #            segmentId // 256 // 256]
      # so segmentId = R + 256 * G + 256 * 256 * B.
      golden_label = np.dot(np.asarray(golden_image), [1, 256, 256 * 256])
    np.testing.assert_array_equal(panoptic_label, golden_label)
if __name__ == '__main__':
  # Running this module directly hands control to TensorFlow's test runner.
  tf.test.main()