Spaces:
Runtime error
Runtime error
| import unittest | |
| import pytest | |
| from numpy import random | |
| from utils.note_event_dataclasses import NoteEvent, Event, EventRange | |
| from utils.event_codec import FastCodec as Codec | |
| from utils.tokenizer import EventTokenizer | |
| from utils.tokenizer import NoteEventTokenizer | |
| class TestEventTokenizerBase(unittest.TestCase): | |
| def test_encode_and_decode(self): | |
| tokenizer = EventTokenizer() | |
| events = [ | |
| Event('pitch', 64), | |
| Event('velocity', 1), | |
| Event('tie', 0), | |
| Event('program', 10), | |
| Event('drum', 0) | |
| ] | |
| tokens = tokenizer.encode(events) | |
| decoded_events = tokenizer.decode(tokens) | |
| self.assertEqual(events, decoded_events) | |
| def test_unknown_codec_name(self): | |
| with self.assertRaises(ValueError): | |
| EventTokenizer(base_codec='unknown') | |
| def test_unknown_codec_type(self): | |
| with self.assertRaises(TypeError): | |
| EventTokenizer(base_codec=123) | |
| def test_encode_and_decode_with_custom_codec(self): | |
| special_tokens = ['PAD', 'EOS', 'SOS', 'T'] | |
| max_shift_steps = 100 | |
| event_ranges = [ | |
| EventRange('eat', min_value=0, max_value=9), | |
| EventRange('sleep', min_value=0, max_value=9), | |
| EventRange('play', min_value=0, max_value=1) | |
| ] | |
| my_codec = Codec(special_tokens, max_shift_steps, event_ranges) | |
| tokenizer = EventTokenizer(my_codec) | |
| events = [ | |
| Event('eat', 3), | |
| Event('shift', 9), | |
| Event('sleep', 9), | |
| Event('shift', 20), | |
| Event('play', 1) | |
| ] | |
| tokens = tokenizer.encode(events) | |
| # 0~3: special tokens | |
| # 4~103: shift tokens | |
| # 104~112: eat tokens | |
| # 113~121: sleep tokens | |
| # 122~123: play tokens | |
| expected_tokens = [107, 13, 123, 24, 125] | |
| self.assertEqual(tokens, expected_tokens) | |
| decoded_events = tokenizer.decode(tokens) | |
| self.assertEqual(events, decoded_events) | |
| class TestEventTokenizerBaseProcessTime(unittest.TestCase): | |
| def setUp(self) -> None: | |
| self.tokenizer = EventTokenizer('mt3') | |
| self.random_tokens = random.randint(0, 500, size=333) | |
| self.events = [ | |
| Event(type='pitch', value=60), | |
| Event(type='velocity', value=1), | |
| Event(type='program', value=0), | |
| Event(type='shift', value=10), | |
| Event(type='tie', value=0), | |
| Event(type='drum', value=0), | |
| ] * 55 | |
| # 32 ms --> 8 ms | |
| def test_event_tokenizer_encode(self): | |
| for i in range(64): | |
| encoded = self.tokenizer.encode(self.events) | |
| # 40 ms --> 10 ms | |
| def test_event_tokenizer_decode(self): | |
| for i in range(64): | |
| decoded = self.tokenizer.decode(self.random_tokens) | |
| # yapf: disable | |
| class NoteEventTokenizerTest(unittest.TestCase): | |
| def test_note_event_tokenizer_encode(self): | |
| tokenizer = NoteEventTokenizer() | |
| note_events = [ | |
| NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set()), | |
| NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36, activity=set()), | |
| NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity=set()) | |
| ] | |
| tokens = tokenizer.encode(note_events) | |
| decoded_events, decoded_tie_events, last_activity, err_cnt = tokenizer.decode(tokens) | |
| self.assertSequenceEqual(note_events, decoded_events) | |
| self.assertSequenceEqual([], decoded_tie_events) | |
| self.assertEqual(len(last_activity), 0) | |
| self.assertEqual(len(err_cnt), 0) | |
| def test_note_event_tokenizer_encode_plus(self): | |
| tokenizer = NoteEventTokenizer() | |
| note_events = [ | |
| NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set()), | |
| NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36, activity=set()), | |
| NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity=set()) | |
| ] | |
| tokens = tokenizer.encode_plus(note_events, max_length=30) | |
| decoded_events, decoded_tie_events, last_activity, err_cnt = tokenizer.decode(tokens) | |
| self.assertSequenceEqual(note_events, decoded_events) | |
| self.assertSequenceEqual([], decoded_tie_events) | |
| self.assertEqual(len(last_activity), 0) | |
| self.assertEqual(len(err_cnt), 0) | |
| # yapf: enable | |
| if __name__ == '__main__': | |
| unittest.main() |