# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
""" tokenizer.py: Encodes and decodes events to/from tokens. """
import numpy as np
import warnings
from abc import ABC, abstractmethod
from utils.note_event_dataclasses import Event, EventRange, Note  # , Codec
from utils.event_codec import FastCodec as Codec
from utils.note_event_dataclasses import NoteEvent
from utils.note2event import note_event2event
from utils.event2note import event2note_event, note_event2note
from typing import List, Optional, Union, Tuple, Dict, Counter


# TODO: Too complex to be an abstract class.
class EventTokenizerBase(ABC):
    """
    A base class for encoding and decoding events to and from tokens.
    """

    def __init__(
        self,
        base_codec: Union[Codec, str] = 'mt3',
        special_tokens: List[str] = ['PAD', 'EOS', 'UNK'],
        extra_tokens: List[str] = [],
        max_shift_steps: int = 206,  # 1001 in Gardner et al.
        program_vocabulary: Optional[Dict] = None,
        drum_vocabulary: Optional[Dict] = None,
    ) -> None:
        """
        Initializes the EventTokenizerBase object.

        :param base_codec: The codec to use for encoding and decoding.
        :param special_tokens: None or a list of special tokens to include in the vocabulary.
        :param extra_tokens: None or a list of tokens to be treated as additional special tokens.
        :param max_shift_steps: The maximum number of shift steps to use for the codec.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        """
        # Initialize the codec attribute based on the input codec parameter.
        if isinstance(base_codec, str):
            # If base_codec is a string, build the matching Codec object.
            if base_codec.lower() == 'mt3':
                event_ranges = [
                    EventRange('pitch', min_value=0, max_value=127),
                    EventRange('velocity', min_value=0, max_value=1),
                    EventRange('tie', min_value=0, max_value=0),
                    EventRange('program', min_value=0, max_value=127),
                    EventRange('drum', min_value=0, max_value=127),
                ]
            else:
                raise ValueError(f'Unknown codec name: {base_codec}')
            # Initialize the codec.
            self.codec = Codec(special_tokens=special_tokens + extra_tokens,
                               max_shift_steps=max_shift_steps,
                               event_ranges=event_ranges,
                               program_vocabulary=program_vocabulary,
                               drum_vocabulary=drum_vocabulary,
                               name='mt3')
        elif isinstance(base_codec, Codec):
            # If base_codec is already a Codec object, store it directly.
            self.codec = base_codec
            if program_vocabulary is not None or drum_vocabulary is not None:
                warnings.warn("Vocabulary cannot be applied when using a custom codec.")
        else:
            # If base_codec is neither a string nor a Codec object, raise a TypeError.
            raise TypeError(f'Unknown codec type: {type(base_codec)}')
        self.num_tokens = self.codec._num_classes

    def _encode(self, events: List[Event]) -> List[int]:
        return [self.codec.encode_event(e) for e in events]

    def _decode(self, tokens: List[int]) -> List[Event]:
        return [self.codec.decode_event_index(idx) for idx in tokens]

    def encode(self):
        """ Encode your custom events to tokens. """
        pass

    def decode(self):
        """ Decode your custom tokens to events. """
        pass


class EventTokenizer(EventTokenizerBase):
    """
    Encodes and decodes events to and from tokens.
    """

    def __init__(self,
                 base_codec: Union[Codec, str] = 'mt3',
                 special_tokens: List[str] = ['PAD', 'EOS', 'UNK'],
                 extra_tokens: List[str] = [],
                 max_shift_steps: int = 206,
                 program_vocabulary: Optional[Dict] = None,
                 drum_vocabulary: Optional[Dict] = None) -> None:
        """
        Initializes the EventTokenizer object.

        :param base_codec: The codec to use for encoding and decoding.
        :param special_tokens: None or a list of special tokens to include in the vocabulary.
        :param extra_tokens: None or a list of tokens to be treated as additional special tokens.
        :param max_shift_steps: The maximum number of shift steps to use for the codec.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        """
        super().__init__(
            base_codec=base_codec,
            special_tokens=special_tokens,
            extra_tokens=extra_tokens,
            max_shift_steps=max_shift_steps,
            program_vocabulary=program_vocabulary,
            drum_vocabulary=drum_vocabulary,
        )

    def encode(self, events):
        """ Encode your custom events to tokens. """
        return super()._encode(events)

    def decode(self, tokens):
        """ Decode your custom tokens to events. """
        return super()._decode(tokens)
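

# Usage sketch (hedged): with the default 'mt3' codec, raw events round-trip
# through token ids. `Event` is assumed to be the (type, value) dataclass from
# utils.note_event_dataclasses.
#
#   tk = EventTokenizer('mt3')
#   ids = tk.encode([Event('pitch', 60), Event('velocity', 1)])
#   events = tk.decode(ids)  # round-trips back to the original Event list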


class NoteEventTokenizer(EventTokenizerBase):
    """ Encodes and decodes note events to/from tokens. """

    def __init__(
            self,
            base_codec: Union[Codec, str] = 'mt3',
            max_length: int = 1024,  # max length of tokens
            tps: int = 100,
            sort_note_event: bool = True,
            special_tokens: List[str] = ['PAD', 'EOS', 'UNK'],
            extra_tokens: List[str] = [],
            max_shift_steps: int = 206,
            program_vocabulary: Optional[Dict] = None,
            drum_vocabulary: Optional[Dict] = None,
            ignore_decoding_tokens: List[str] = [],
            ignore_decoding_tokens_from_and_to: Optional[List[str]] = None,
            debug_mode: bool = False) -> None:
        """
        Initializes the NoteEventTokenizer object.

        List[NoteEvent] -> encode() -> List[int]
        List[int] -> decode() -> Tuple[List[NoteEvent], List[NoteEvent], ...]

        :param base_codec: The codec to use for encoding and decoding.
        :param max_length: Maximum token sequence length used for padding and truncation.
        :param tps: Time steps per second used when quantizing event times.
        :param sort_note_event: Whether to sort note events before encoding.
        :param special_tokens: None or a list of special tokens to include in the vocabulary.
        :param extra_tokens: None or a list of tokens to be treated as additional special tokens.
        :param max_shift_steps: The maximum number of shift steps to use for the codec.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        :param ignore_decoding_tokens: List of special tokens to ignore during decoding.
        :param ignore_decoding_tokens_from_and_to: [from_token, to_token] range of tokens to
            ignore during decoding.
        :param debug_mode: If True, print the tokens ignored during decoding.
        """
        super().__init__(base_codec=base_codec,
                         special_tokens=special_tokens,
                         extra_tokens=extra_tokens,
                         max_shift_steps=max_shift_steps,
                         program_vocabulary=program_vocabulary,
                         drum_vocabulary=drum_vocabulary)
        self.max_length = max_length
        self.tps = tps
        self.sort = sort_note_event

        # Prepare prefix, suffix and pad tokens.
        self._prefix = []
        self._suffix = []
        for stk in self.codec.special_tokens:
            if stk == 'EOS':
                self._suffix.append(self.codec.special_tokens.index('EOS'))
            elif stk == 'PAD':
                # Pad with the PAD id, sized to max_length (was a hardcoded [0] * 1024).
                self._zero_pad = [self.codec.special_tokens.index('PAD')] * max_length
            elif stk == 'UNK':
                pass
            else:
                pass
                # raise NotImplementedError(f'Unknown special token: {stk}')
        self.eos_id = self.codec.special_tokens.index('EOS')
        self.pad_id = self.codec.special_tokens.index('PAD')
        self._zero_pad_task = []  # lazily resized in encode_task()
        self.ids_to_ignore_decoding = [self.codec.special_tokens.index(t) for t in ignore_decoding_tokens]
        self.ignore_tokens_from_and_to = ignore_decoding_tokens_from_and_to
        self.debug_mode = debug_mode

    def _decode(self, tokens):
        # Event-level detokenizer (not note_event): required for displaying raw
        # events in the validation dashboard.
        return super()._decode(tokens)

    def encode(
        self,
        note_events: List[NoteEvent],
        tie_note_events: Optional[List[NoteEvent]] = None,
        start_time: float = 0.,
    ) -> List[int]:
        """ Encodes note events and tie note events to tokens. """
        events = note_event2event(
            note_events=note_events,
            tie_note_events=tie_note_events,
            start_time=start_time,  # required for calculating relative time
            tps=self.tps,
            sort=self.sort)
        return super()._encode(events)

    def encode_plus(
            self,
            note_events: List[NoteEvent],
            tie_note_events: Optional[List[NoteEvent]] = None,
            start_times: float = 0.,  # renamed from start_time (bug fix)
            add_special_tokens: Optional[bool] = True,
            max_length: Optional[int] = None,  # if None, use self.max_length
            pad_to_max_length: Optional[bool] = True,
            return_attention_mask: bool = False) -> Union[List[int], Tuple[List[int], List[int]]]:
        """ Encodes note events and tie note info to padded tokens. """
        encoded = self.encode(note_events, tie_note_events, start_times)
        # if task_events:
        #     encoded = super()._encode(task_events) + encoded

        if add_special_tokens:
            if self._prefix:
                encoded = self._prefix + encoded
            if self._suffix:
                encoded = encoded + self._suffix

        if max_length is None:
            max_length = self.max_length

        length = len(encoded)
        if length >= max_length:
            encoded = encoded[:max_length]
            length = max_length

        if return_attention_mask:
            attention_mask = [1] * length

        # <PAD>
        if pad_to_max_length is True:
            if len(self._zero_pad) != max_length:
                self._zero_pad = [self.pad_id] * max_length
            if return_attention_mask:
                # The mask must be 0 over padding, regardless of the PAD id value.
                attention_mask += [0] * (max_length - length)
            encoded = encoded + self._zero_pad[length:]

        if return_attention_mask:
            return encoded, attention_mask
        return encoded
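
    # encode_plus sketch (hedged; assumes the default special tokens, so
    # pad_id == 0 and eos_id == 1): encode() output gets the EOS suffix
    # appended, is truncated to max_length, then right-padded with pad_id.
    # With return_attention_mask=True, the mask is 1 over real tokens
    # (including EOS) and 0 over padding:
    #
    #   tk = NoteEventTokenizer('mt3', max_length=16)
    #   ids, mask = tk.encode_plus(note_events, return_attention_mask=True)
    #   assert len(ids) == len(mask) == 16  # mask marks the non-PAD prefix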

    def encode_task(self, task_events: List[Event], max_length: Optional[int] = None) -> List[int]:
        # NOTE: This encodes task events to task token ids; it does not take
        # note_event objects.
        encoded = super()._encode(task_events)

        # <PAD>
        if max_length is not None:
            if len(self._zero_pad_task) != max_length:
                self._zero_pad_task = [self.pad_id] * max_length
            length = len(encoded)
            encoded = encoded + self._zero_pad_task[length:]  # was self._zero_pad (bug)
        return encoded

    def decode(
        self,
        tokens: List[int],
        start_time: float = 0.,
        return_events: bool = False,
    ) -> Union[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], int],
               Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[Event], int]]:
        """Decodes a sequence of tokens into note events.

        Args:
            tokens (List[int]): The list of tokens to be decoded.
            start_time (float, optional): The starting time for the note events. Defaults to 0.
            return_events (bool, optional): Whether to include the raw events in the return value.
                Defaults to False.

        Returns:
            If `return_events` is False, a tuple of `note_events`, `tie_note_events`,
            `last_activity`, and `err_cnt`.
            If `return_events` is True, a tuple of `note_events`, `tie_note_events`,
            `last_activity`, `events`, and `err_cnt`.
        """
        if self.debug_mode:
            ignored_tokens_from_input = [t for t in tokens if t in self.ids_to_ignore_decoding]
            print('ignored tokens from input:', ignored_tokens_from_input)
        if self.ids_to_ignore_decoding:
            tokens = [t for t in tokens if t not in self.ids_to_ignore_decoding]
        events = super()._decode(tokens)
        note_events, tie_note_events, last_activity, err_cnt = event2note_event(events, start_time, True, self.tps)
        if return_events:
            return note_events, tie_note_events, last_activity, events, err_cnt
        else:
            return note_events, tie_note_events, last_activity, err_cnt
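
    # decode sketch: given token ids from a model, recover the note events plus
    # segment-boundary state (the "tie" notes still sounding and last activity):
    #
    #   note_events, ties, last_activity, err_cnt = tk.decode(ids, start_time=4.096)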

    def decode_batch(
        self,
        batch_tokens: Union[List[List[int]], np.ndarray],
        start_times: List[float],
        return_events: bool = False
    ) -> Union[Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], float]], int],
               Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], float]],
                     List[List[Event]], int]]:
        """
        Decodes a batch of tokens to note_events and tie_note_events.

        Args:
            batch_tokens (List[List[int]] or np.ndarray): Tokens to be decoded.
            start_times (List[float]): List of start times, one per token sequence.
            return_events (bool, optional): Whether raw events should also be returned. Defaults to False.
        """
        if isinstance(batch_tokens, np.ndarray):
            batch_tokens = batch_tokens.tolist()

        if len(batch_tokens) != len(start_times):
            raise ValueError('The lengths of batch_tokens and start_times must be the same.')

        zipped_note_events_and_tie = []
        list_events = []
        total_err_cnt = 0

        for tokens, start_time in zip(batch_tokens, start_times):
            if return_events:
                note_events, tie_note_events, last_activity, events, err_cnt = self.decode(
                    tokens, start_time, return_events)
                list_events.append(events)
            else:
                note_events, tie_note_events, last_activity, err_cnt = self.decode(tokens, start_time, return_events)
            zipped_note_events_and_tie.append((note_events, tie_note_events, last_activity, start_time))
            total_err_cnt += err_cnt

        if return_events:
            return zipped_note_events_and_tie, list_events, total_err_cnt
        else:
            return zipped_note_events_and_tie, total_err_cnt
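
    # decode_batch sketch: rows of a (batch, length) token array are decoded
    # independently, each with its own segment start time, and zipped with
    # that start time. `ids_seg0`/`ids_seg1` below are hypothetical rows:
    #
    #   batch = np.stack([ids_seg0, ids_seg1])
    #   zipped, err_cnt = tk.decode_batch(batch, start_times=[0.0, 2.048])
    #   note_events, tie_note_events, last_activity, t0 = zipped[0]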

    def decode_list_batches(
        self,
        list_batch_tokens: Union[List[List[List[int]]], List[np.ndarray]],
        list_start_times: Union[List[List[float]], List[float]],
        return_events: bool = False
    ) -> Union[Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], float]], int],
               Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], float]],
                     List[List[Event]], int]]:
        """
        Decodes a list of variable-size batches of token arrays to a list of
        zipped note_events and tie_note_events.

        Args:
            list_batch_tokens: List[np.ndarray], where each array has shape (batch_size, variable_length).
            list_start_times: List[float], whose length is the sum of all batch sizes.
            return_events: bool. Defaults to False.

        Returns:
            zipped_note_events_and_tie:
                List[
                    Tuple[
                        List[NoteEvent]: A list of note events.
                        List[NoteEvent]: A list of tie note events.
                        List[Tuple[int]]: The last activity of the segment, [(program, pitch), ...]. This is
                            useful for validating notes within a batch of segments extracted from a file.
                        float: The segment start time.
                    ]
                ]
            (Optional) list_events:
                List[List[Event]]
            total_err_cnt:
                int: Total error count.
        """
        list_tokens = []
        for arr in list_batch_tokens:
            for tokens in arr:
                list_tokens.append(tokens)
        assert len(list_tokens) == len(list_start_times)

        zipped_note_events_and_tie = []
        list_events = []
        total_err_cnt = 0

        for tokens, start_time in zip(list_tokens, list_start_times):
            # Always request raw events here, since five values are unpacked below
            # (the original unpacked five values even when return_events was False).
            note_events, tie_note_events, last_activity, events, err_cnt = self.decode(
                tokens, start_time, return_events=True)
            zipped_note_events_and_tie.append((note_events, tie_note_events, last_activity, start_time))
            if return_events:
                list_events.append(events)
            total_err_cnt += err_cnt

        if return_events:
            return zipped_note_events_and_tie, list_events, total_err_cnt
        else:
            return zipped_note_events_and_tie, total_err_cnt
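

# Minimal smoke test (a sketch, not part of the library API). The NoteEvent
# field names below (is_drum, program, time, velocity, pitch) are assumed from
# utils.note_event_dataclasses; adjust them if your checkout differs.
if __name__ == '__main__':
    tokenizer = NoteEventTokenizer('mt3')
    note_events = [
        NoteEvent(is_drum=False, program=0, time=0.0, velocity=1, pitch=60),  # onset
        NoteEvent(is_drum=False, program=0, time=0.5, velocity=0, pitch=60),  # offset
    ]
    tokens = tokenizer.encode(note_events, start_time=0.)
    note_events_out, tie_out, last_activity, err_cnt = tokenizer.decode(tokens, start_time=0.)
    print('tokens :', tokens)
    print('decoded:', note_events_out)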