Spaces:
Runtime error
Runtime error
| # Copyright (c) 2017-present, Facebook, Inc. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the LICENSE file in | |
| # the root directory of this source tree. An additional grant of patent rights | |
| # can be found in the PATENTS file in the same directory. | |
| import logging | |
| import sys | |
| import torch | |
| from dataclasses import dataclass, field | |
| from typing import List, Optional, Tuple | |
| from fairseq.data import Dictionary | |
| from fairseq.data.codedataset import ExpressiveCodeDataConfig, CodeDataset | |
| from fairseq.dataclass.configs import FairseqDataclass | |
| from fairseq.tasks import register_task | |
| from fairseq.tasks.fairseq_task import FairseqTask | |
| from omegaconf import MISSING, DictConfig | |
| logger = logging.getLogger(__name__) | |
| class UnitDictionary(Dictionary): | |
| """ | |
| A fixed-sized Dictionary that operates on integer-valued tokens | |
| wth a trivial (identity) token <-> id mapping. | |
| Special symbols (bos, eos, ...) have ids above n_units. | |
| """ | |
| def __init__( | |
| self, | |
| *, # begin keyword-only arguments | |
| n_units, | |
| bos="<s>", | |
| pad="<pad>", | |
| eos="</s>", | |
| unk="<unk>", | |
| extra_special_symbols=None, | |
| clip=False, | |
| ): | |
| self.n_units = n_units | |
| self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos | |
| self.clip = clip | |
| self.symbols = [] | |
| self.count = [] | |
| self.indices = {} | |
| for i in range(n_units): | |
| self.add_symbol(str(i)) | |
| self.bos_index = self.add_symbol(bos) | |
| self.pad_index = self.add_symbol(pad) | |
| self.eos_index = self.add_symbol(eos) | |
| self.unk_index = self.add_symbol(unk) | |
| if extra_special_symbols: | |
| for s in extra_special_symbols: | |
| self.add_symbol(s) | |
| self.nspecial = len(self.symbols) | |
| def encode_line(self, line, append_eos=True, prepend_bos=False) -> torch.IntTensor: | |
| words = [int(x) for x in line.split()] | |
| if self.clip: | |
| words = [min(self.n_units - 1, word) for word in words] | |
| if prepend_bos: | |
| words = [self.bos_index] + words | |
| if append_eos: | |
| words.append(self.eos_index) | |
| ids = torch.IntTensor(words) | |
| return ids | |
| class SpeechUnitModelingConfig(FairseqDataclass): | |
| data: str = field(default=MISSING, metadata={"help": "Path to data config.json"}) | |
| max_token_duration: int = field( | |
| default=20, metadata={"help": "all token durations are capped to this value"} | |
| ) | |
| tokens_per_sample: int = field( | |
| default=1024, metadata={"help": "tokens in a sample"} | |
| ) | |
| max_target_positions: int = field( | |
| default=1024, metadata={"help": "max target positions"} | |
| ) | |
| # duration modeling | |
| ignore_duration_input: bool = field( | |
| default=False, metadata={"help": "whether token durations should be zeroed out"} | |
| ) | |
| discrete_duration: bool = field( | |
| default=False, metadata={"help": "treat duration as discrete variable"} | |
| ) | |
| # F0 modeling | |
| ignore_f0_input: bool = field( | |
| default=False, metadata={"help": "whether F0 should be zeroed out"} | |
| ) | |
| discrete_f0: bool = field( | |
| default=False, metadata={"help": "load quantized f0. get bin from config"} | |
| ) | |
| log_f0: bool = field( | |
| default=False, metadata={"help": "whether f0 should be modeled in log space"} | |
| ) | |
| normalize_f0_mean: bool = field( | |
| default=False, metadata={"help": "whether normalize f0 by speaker mean"} | |
| ) | |
| normalize_f0_std: bool = field( | |
| default=False, metadata={"help": "whether normalize f0 by speaker stddev"} | |
| ) | |
| interpolate_f0: bool = field( | |
| default=False, | |
| metadata={"help": "whether interpolate f0 for non-voiced segments"}, | |
| ) | |
| # input/output streams | |
| stream_shifts: str = field( | |
| default="0,0", | |
| metadata={ | |
| "help": ( | |
| "comma-separated integer list denoting right-shift for " | |
| "duration and pitch streams" | |
| ) | |
| }, | |
| ) | |
| class SpeechUnitLanguageModelingTask(FairseqTask): | |
| def __init__(self, cfg: SpeechUnitModelingConfig) -> None: | |
| super().__init__(cfg) | |
| assert not self.cfg.normalize_f0_std or self.cfg.normalize_f0_mean | |
| self.data_config = ExpressiveCodeDataConfig(cfg.data) | |
| self._source_dictionary = self._target_dictionary = UnitDictionary( | |
| n_units=self.data_config.n_units | |
| ) | |
| self._source_duration_dictionary = self._target_duration_dictionary = ( | |
| UnitDictionary(n_units=self.cfg.max_token_duration + 1, clip=True) | |
| if self.cfg.discrete_duration | |
| else None | |
| ) | |
| self._source_f0_dictionary = self._target_f0_dictionary = ( | |
| UnitDictionary(n_units=self.data_config.f0_vq_n_units) | |
| if self.cfg.discrete_f0 | |
| else None | |
| ) | |
| self._channel_names = ["token", "duration", "f0"] | |
| self._channel_sizes = [ | |
| len(self.target_dictionary), | |
| len(self.target_duration_dictionary) if self.cfg.discrete_duration else 1, | |
| len(self.target_f0_dictionary) if self.cfg.discrete_f0 else 1, | |
| ] | |
| def source_dictionary(self) -> Optional[Dictionary]: | |
| return self._source_dictionary | |
| def source_duration_dictionary(self) -> Optional[Dictionary]: | |
| return self._source_duration_dictionary | |
| def source_f0_dictionary(self) -> Optional[Dictionary]: | |
| return self._source_f0_dictionary | |
| def channel_names(self) -> List[str]: | |
| return self._channel_names | |
| def channel_sizes(self) -> List[int]: | |
| return self._channel_sizes | |
| def dictionary(self) -> Optional[Dictionary]: | |
| return self._source_dictionary | |
| def target_dictionary(self) -> Optional[Dictionary]: | |
| return self._target_dictionary | |
| def target_duration_dictionary(self) -> Optional[Dictionary]: | |
| return self._target_duration_dictionary | |
| def target_f0_dictionary(self) -> Optional[Dictionary]: | |
| return self._target_f0_dictionary | |
| def dictionaries(self) -> List[Dictionary]: | |
| return [self._dictionaries[l] for l in self.cfg.labels] | |
| def setup_task( | |
| cls, cfg: SpeechUnitModelingConfig, **kwargs | |
| ) -> "SpeechUnitLanguageModelingTask": | |
| return cls(cfg) | |
| def load_dataset(self, split: str, **kwargs) -> None: | |
| self.datasets[split] = CodeDataset( | |
| manifest=self.data_config.manifests[split], | |
| dictionary=self.source_dictionary, | |
| dur_dictionary=self.source_duration_dictionary, | |
| f0_dictionary=self.source_f0_dictionary, | |
| config=self.data_config, | |
| discrete_dur=self.cfg.discrete_duration, | |
| discrete_f0=self.cfg.discrete_f0, | |
| log_f0=self.cfg.log_f0, | |
| normalize_f0_mean=self.cfg.normalize_f0_mean, | |
| normalize_f0_std=self.cfg.normalize_f0_std, | |
| interpolate_f0=self.cfg.interpolate_f0, | |
| shifts=self.cfg.stream_shifts, | |
| ) | |
| def max_positions(self) -> Tuple[int, int]: | |
| return (sys.maxsize, sys.maxsize) | |
| def build_criterion(self, cfg: DictConfig): | |
| import fairseq.criterions | |
| return fairseq.criterions.build_criterion(cfg, self) | |