Spaces:
Running
Running
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import unittest | |
| import uuid | |
| from fairseq import metrics | |
class TestMetrics(unittest.TestCase):
    """Checks how scalars logged via ``metrics.log_scalar`` are attributed to
    nested, new-root and named ``metrics.aggregate`` contexts."""

    @staticmethod
    def _loss(aggregator):
        """Return the smoothed "loss" value held by *aggregator*."""
        return aggregator.get_smoothed_values()["loss"]

    def test_nesting(self):
        """A scalar logged in a nested context also counts for the outer one."""
        with metrics.aggregate() as outer:
            metrics.log_scalar("loss", 1)
            with metrics.aggregate() as inner:
                metrics.log_scalar("loss", 2)

        # outer saw 1 and 2 (expected 1.5); inner saw only 2
        self.assertEqual(self._loss(outer), 1.5)
        self.assertEqual(self._loss(inner), 2)

    def test_new_root(self):
        """With ``new_root=True`` the inner scalar must not reach the outer context."""
        with metrics.aggregate() as outer:
            metrics.log_scalar("loss", 1)
            with metrics.aggregate(new_root=True) as detached:
                metrics.log_scalar("loss", 2)

        self.assertEqual(self._loss(outer), 1)
        self.assertEqual(self._loss(detached), 2)

    def test_nested_new_root(self):
        """Alternating regular and new-root contexts, four levels deep."""
        with metrics.aggregate() as level_1:
            metrics.log_scalar("loss", 1)
            with metrics.aggregate(new_root=True) as level_2:
                metrics.log_scalar("loss", 2)
                with metrics.aggregate() as level_3:
                    metrics.log_scalar("loss", 3)
                    with metrics.aggregate(new_root=True) as level_4:
                        metrics.log_scalar("loss", 4)
            # logged after level_2 has exited, so only level_1 is active here
            metrics.log_scalar("loss", 1.5)

        # Checked innermost first:
        #   level_4 -> 4
        #   level_3 -> 3            (level_4 is a new root)
        #   level_2 -> 2.5          (2 and 3)
        #   level_1 -> 1.25         (1 and 1.5; level_2 is a new root)
        self.assertEqual(self._loss(level_4), 4)
        self.assertEqual(self._loss(level_3), 3)
        self.assertEqual(self._loss(level_2), 2.5)
        self.assertEqual(self._loss(level_1), 1.25)

    def test_named(self):
        """Re-entering a named context keeps accumulating into the same meters."""
        # a random name avoids clashing with any other named aggregator
        tag = str(uuid.uuid4())
        metrics.reset_meters(tag)

        with metrics.aggregate(tag):
            metrics.log_scalar("loss", 1)

        # logged outside the named context: must not count towards `tag`
        metrics.log_scalar("loss", 3)

        with metrics.aggregate(tag):
            metrics.log_scalar("loss", 2)

        # only 1 and 2 were logged under `tag`
        self.assertEqual(metrics.get_smoothed_values(tag)["loss"], 1.5)

    def test_nested_duplicate_names(self):
        """Entering an already-active name again must not double-count scalars."""
        tag = str(uuid.uuid4())
        metrics.reset_meters(tag)

        with metrics.aggregate(tag):
            metrics.log_scalar("loss", 1)
            # `tag` is entered a second time underneath an anonymous context
            with metrics.aggregate() as other, metrics.aggregate(tag):
                metrics.log_scalar("loss", 2)
            metrics.log_scalar("loss", 6)

        # `tag` saw 1, 2 and 6 exactly once each (expected 3); `other` saw only 2
        self.assertEqual(metrics.get_smoothed_values(tag)["loss"], 3)
        self.assertEqual(self._loss(other), 2)
if __name__ == "__main__":
    # Executed directly (not imported): hand control to the unittest runner.
    unittest.main()