Spaces:
Runtime error
Runtime error
| # Copyright 2024 The YourMT3 Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Please see the details in the LICENSE file. | |
"""metrics_test.py:
This file contains tests for the following:
    • AMTMetrics (class)
    • compute_track_metrics (function)
"""
| import unittest | |
| import warnings | |
| import torch | |
| import numpy as np | |
| from utils.metrics import AMTMetrics | |
| from utils.metrics import compute_track_metrics | |
class TestAMTMetrics(unittest.TestCase):
    """Tests for AMTMetrics and compute_track_metrics."""

    def test_individual_attributes(self):
        """Update, compute, and reset a single metric attribute (onset_f)."""
        metric = AMTMetrics()

        # Update the metric using the .update() method.
        metric.onset_f.update(0.5)
        # Update the metric using the __call__() method.
        metric.onset_f(0.5)
        # Update the metric with an explicit weight.
        metric.onset_f(0, weight=1.0)

        # The three updates (0.5, 0.5, 0) are expected to average to 1/3.
        computed_value = metric.onset_f.compute()
        self.assertAlmostEqual(computed_value, 0.3333333333333333)

        # After reset(), compute() is expected to emit a UserWarning and yield NaN.
        metric.onset_f.reset()
        with self.assertWarns(UserWarning):
            reset_value = metric.onset_f.compute()
        # Check NaN explicitly: NaN is truthy, so a plain truthiness assertion
        # (e.g. torch._assert(value, ...)) would pass for any non-zero value.
        self.assertTrue(np.isnan(float(reset_value)))

        # bulk_compute() on the reset metric set is also expected to warn.
        with self.assertWarns(UserWarning):
            metric.bulk_compute()

    def test_bulk_update_and_compute(self):
        """bulk_update() accepts plain values and {'value', 'weight'} dicts."""
        metric = AMTMetrics()

        # bulk_update with values only.
        d1 = {'onset_f': 0.5, 'offset_f': 0.5}
        metric.bulk_update(d1)

        # bulk_update with values and weights.
        d2 = {'onset_f': {'value': 0.5, 'weight': 1.0}, 'offset_f': {'value': 0.5, 'weight': 1.0}}
        metric.bulk_update(d2)

        computed_metrics = metric.bulk_compute()

        # Both keys must be present in the computed results.
        self.assertIn('onset_f', computed_metrics)
        self.assertIn('offset_f', computed_metrics)

        # Every update was 0.5, so each average must be 0.5.
        self.assertAlmostEqual(computed_metrics['onset_f'], 0.5)
        self.assertAlmostEqual(computed_metrics['offset_f'], 0.5)

    def test_compute_track_metrics_singing(self):
        """A transcription decoded from the reference events must score 1.0.

        The estimated notes are decoded from the reference note events. Every
        'Singing Voice' metric is therefore expected to be perfect under both
        the singing-only vocabulary and the GM-instrument-plus vocabulary.
        """
        from config.vocabulary import SINGING_SOLO_CLASS, GM_INSTR_CLASS_PLUS
        from utils.event2note import note_event2note

        ref_notes_dict = np.load('extras/examples/singing_notes.npy', allow_pickle=True).tolist()
        ref_note_events_dict = np.load('extras/examples/singing_note_events.npy', allow_pickle=True).tolist()
        est_notes, _ = note_event2note(ref_note_events_dict['note_events'])
        ref_notes = ref_notes_dict['notes']

        vocabularies = (
            ('SINGING_SOLO_CLASS', SINGING_SOLO_CLASS),
            ('GM_INSTR_CLASS_PLUS', GM_INSTR_CLASS_PLUS),
        )
        for vocab_name, eval_vocab in vocabularies:
            # subTest reports which vocabulary failed and still runs the other.
            with self.subTest(vocab=vocab_name):
                self._assert_perfect_singing_scores(est_notes, ref_notes, eval_vocab)

    def _assert_perfect_singing_scores(self, est_notes, ref_notes, eval_vocab):
        """Evaluate est_notes against ref_notes using eval_vocab.

        Asserts that exactly six computed metrics have 'Singing Voice' in their
        name and that each of them equals 1.0.

        Args:
            est_notes: Estimated notes, as returned by note_event2note.
            ref_notes: Reference notes.
            eval_vocab: Evaluation vocabulary. Its keys are passed to
                AMTMetrics as extra classes.
        """
        metric = AMTMetrics(prefix='test/', extra_classes=list(eval_vocab))
        drum_metric, non_drum_metric, instr_metric = compute_track_metrics(
            est_notes,
            ref_notes,
            eval_vocab=eval_vocab,
            eval_drum_vocab=None,
            onset_tolerance=0.05,
        )
        # Accumulate the three partial metric dicts in the original order.
        for partial_metrics in (drum_metric, non_drum_metric, instr_metric):
            metric.bulk_update(partial_metrics)
        computed_metrics = metric.bulk_compute()

        singing_metrics = {k: v for k, v in computed_metrics.items() if 'Singing Voice' in k}
        self.assertEqual(len(singing_metrics), 6,
                         msg=f"unexpected 'Singing Voice' metrics: {sorted(singing_metrics)}")
        for name, value in singing_metrics.items():
            self.assertEqual(value, 1.0, msg=name)
if __name__ == "__main__":
    # Run every TestCase in this module when it is executed as a script.
    unittest.main()